arekit 0.25.0__py3-none-any.whl → 0.25.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115)
  1. arekit/common/context/terms_mapper.py +5 -2
  2. arekit/common/data/input/providers/rows/samples.py +8 -12
  3. arekit/common/data/input/providers/sample/cropped.py +4 -3
  4. arekit/common/data/input/terms_mapper.py +4 -8
  5. arekit/common/data/storages/base.py +4 -18
  6. arekit/common/docs/entities_grouping.py +5 -3
  7. arekit/common/docs/parsed/base.py +3 -3
  8. arekit/common/docs/parsed/providers/base.py +3 -5
  9. arekit/common/docs/parsed/providers/entity_service.py +7 -28
  10. arekit/common/docs/parsed/providers/opinion_pairs.py +6 -6
  11. arekit/common/docs/parsed/providers/text_opinion_pairs.py +4 -4
  12. arekit/common/docs/parsed/service.py +2 -2
  13. arekit/common/docs/parser.py +3 -30
  14. arekit/common/model/labeling/single.py +7 -3
  15. arekit/common/opinions/annot/algo/pair_based.py +9 -5
  16. arekit/common/pipeline/base.py +0 -2
  17. arekit/common/pipeline/batching.py +0 -3
  18. arekit/common/pipeline/items/base.py +1 -1
  19. arekit/common/utils.py +11 -8
  20. arekit/contrib/bert/input/providers/cropped_sample.py +2 -5
  21. arekit/contrib/bert/terms/mapper.py +2 -2
  22. arekit/contrib/prompt/sample.py +2 -6
  23. arekit/contrib/utils/bert/samplers.py +4 -2
  24. arekit/contrib/utils/data/storages/jsonl_based.py +2 -1
  25. arekit/contrib/utils/data/storages/row_cache.py +2 -1
  26. arekit/contrib/utils/data/storages/sqlite_based.py +2 -1
  27. arekit/contrib/utils/pipelines/text_opinion/annot/algo_based.py +8 -5
  28. arekit/contrib/utils/pipelines/text_opinion/extraction.py +16 -8
  29. {arekit-0.25.0.dist-info → arekit-0.25.2.dist-info}/METADATA +10 -8
  30. {arekit-0.25.0.dist-info → arekit-0.25.2.dist-info}/RECORD +34 -115
  31. {arekit-0.25.0.dist-info → arekit-0.25.2.dist-info}/WHEEL +1 -1
  32. arekit/common/data/input/repositories/__init__.py +0 -0
  33. arekit/common/data/input/repositories/base.py +0 -68
  34. arekit/common/data/input/repositories/sample.py +0 -22
  35. arekit/common/data/views/__init__.py +0 -0
  36. arekit/common/data/views/samples.py +0 -26
  37. arekit/common/experiment/__init__.py +0 -0
  38. arekit/common/experiment/api/__init__.py +0 -0
  39. arekit/common/experiment/api/base_samples_io.py +0 -20
  40. arekit/common/experiment/data_type.py +0 -17
  41. arekit/common/service/__init__.py +0 -0
  42. arekit/common/service/sqlite.py +0 -36
  43. arekit/contrib/networks/__init__.py +0 -0
  44. arekit/contrib/networks/embedding.py +0 -149
  45. arekit/contrib/networks/embedding_io.py +0 -18
  46. arekit/contrib/networks/input/__init__.py +0 -0
  47. arekit/contrib/networks/input/const.py +0 -6
  48. arekit/contrib/networks/input/ctx_serialization.py +0 -28
  49. arekit/contrib/networks/input/embedding/__init__.py +0 -0
  50. arekit/contrib/networks/input/embedding/matrix.py +0 -29
  51. arekit/contrib/networks/input/embedding/offsets.py +0 -55
  52. arekit/contrib/networks/input/formatters/__init__.py +0 -0
  53. arekit/contrib/networks/input/formatters/pos_mapper.py +0 -22
  54. arekit/contrib/networks/input/providers/__init__.py +0 -0
  55. arekit/contrib/networks/input/providers/sample.py +0 -129
  56. arekit/contrib/networks/input/providers/term_connotation.py +0 -23
  57. arekit/contrib/networks/input/providers/text.py +0 -24
  58. arekit/contrib/networks/input/rows_parser.py +0 -47
  59. arekit/contrib/networks/input/term_types.py +0 -13
  60. arekit/contrib/networks/input/terms_mapping.py +0 -60
  61. arekit/contrib/networks/vectorizer.py +0 -6
  62. arekit/contrib/utils/data/readers/__init__.py +0 -0
  63. arekit/contrib/utils/data/readers/base.py +0 -7
  64. arekit/contrib/utils/data/readers/csv_pd.py +0 -38
  65. arekit/contrib/utils/data/readers/jsonl.py +0 -15
  66. arekit/contrib/utils/data/readers/sqlite.py +0 -14
  67. arekit/contrib/utils/data/service/__init__.py +0 -0
  68. arekit/contrib/utils/data/service/balance.py +0 -50
  69. arekit/contrib/utils/data/storages/pandas_based.py +0 -123
  70. arekit/contrib/utils/data/writers/csv_native.py +0 -63
  71. arekit/contrib/utils/data/writers/csv_pd.py +0 -40
  72. arekit/contrib/utils/data/writers/json_opennre.py +0 -132
  73. arekit/contrib/utils/data/writers/sqlite_native.py +0 -114
  74. arekit/contrib/utils/embeddings/__init__.py +0 -0
  75. arekit/contrib/utils/embeddings/rusvectores.py +0 -58
  76. arekit/contrib/utils/embeddings/tokens.py +0 -30
  77. arekit/contrib/utils/entities/formatters/str_display.py +0 -11
  78. arekit/contrib/utils/io_utils/embedding.py +0 -72
  79. arekit/contrib/utils/np_utils/__init__.py +0 -0
  80. arekit/contrib/utils/np_utils/embedding.py +0 -22
  81. arekit/contrib/utils/np_utils/npz_utils.py +0 -13
  82. arekit/contrib/utils/np_utils/vocab.py +0 -20
  83. arekit/contrib/utils/pipelines/items/sampling/__init__.py +0 -0
  84. arekit/contrib/utils/pipelines/items/sampling/base.py +0 -94
  85. arekit/contrib/utils/pipelines/items/sampling/networks.py +0 -55
  86. arekit/contrib/utils/pipelines/items/text/entities_default.py +0 -23
  87. arekit/contrib/utils/pipelines/items/text/frames_lemmatized.py +0 -36
  88. arekit/contrib/utils/pipelines/items/text/frames_negation.py +0 -33
  89. arekit/contrib/utils/pipelines/items/text/tokenizer.py +0 -105
  90. arekit/contrib/utils/pipelines/items/text/translator.py +0 -136
  91. arekit/contrib/utils/processing/__init__.py +0 -0
  92. arekit/contrib/utils/processing/languages/__init__.py +0 -0
  93. arekit/contrib/utils/processing/languages/mods.py +0 -12
  94. arekit/contrib/utils/processing/languages/pos.py +0 -23
  95. arekit/contrib/utils/processing/languages/ru/__init__.py +0 -0
  96. arekit/contrib/utils/processing/languages/ru/cases.py +0 -78
  97. arekit/contrib/utils/processing/languages/ru/constants.py +0 -6
  98. arekit/contrib/utils/processing/languages/ru/mods.py +0 -13
  99. arekit/contrib/utils/processing/languages/ru/number.py +0 -23
  100. arekit/contrib/utils/processing/languages/ru/pos_service.py +0 -36
  101. arekit/contrib/utils/processing/lemmatization/__init__.py +0 -0
  102. arekit/contrib/utils/processing/lemmatization/mystem.py +0 -51
  103. arekit/contrib/utils/processing/pos/__init__.py +0 -0
  104. arekit/contrib/utils/processing/pos/base.py +0 -12
  105. arekit/contrib/utils/processing/pos/mystem_wrap.py +0 -134
  106. arekit/contrib/utils/processing/pos/russian.py +0 -10
  107. arekit/contrib/utils/processing/text/__init__.py +0 -0
  108. arekit/contrib/utils/processing/text/tokens.py +0 -127
  109. arekit/contrib/utils/serializer.py +0 -42
  110. arekit/contrib/utils/vectorizers/__init__.py +0 -0
  111. arekit/contrib/utils/vectorizers/bpe.py +0 -93
  112. arekit/contrib/utils/vectorizers/random_norm.py +0 -39
  113. {arekit-0.25.0.data → arekit-0.25.2.data}/data/logo.png +0 -0
  114. {arekit-0.25.0.dist-info → arekit-0.25.2.dist-info}/LICENSE +0 -0
  115. {arekit-0.25.0.dist-info → arekit-0.25.2.dist-info}/top_level.txt +0 -0
arekit/contrib/networks/input/rows_parser.py
@@ -1,47 +0,0 @@
- import arekit.contrib.networks.input.const as const
- from arekit.common.data.rows_fmt import process_indices_list
-
-
- def create_nn_column_formatters(no_value_func=lambda: None, args_sep=","):
-     assert(callable(no_value_func))
-
-     empty_list = []
-
-     def str_to_list(value):
-         return process_indices_list(value, no_value_func=no_value_func, args_sep=args_sep)
-
-     def list_to_str(inds_iter):
-         return args_sep.join([str(i) for i in inds_iter])
-
-     return {
-         const.FrameVariantIndices: {
-             "writer": lambda value: list_to_str(value),
-             "parser": lambda value: process_indices_list(value, no_value_func=no_value_func, args_sep=args_sep)
-             if isinstance(value, str) else empty_list
-         },
-         const.FrameConnotations: {
-             "writer": lambda value: list_to_str(value),
-             "parser": lambda value: process_indices_list(value, no_value_func=no_value_func, args_sep=args_sep)
-             if isinstance(value, str) else empty_list
-         },
-         const.SynonymObject: {
-             "writer": lambda value: list_to_str(value),
-             "parser": lambda value: process_indices_list(value, no_value_func=no_value_func, args_sep=args_sep)
-         },
-         const.SynonymSubject: {
-             "writer": lambda value: list_to_str(value),
-             "parser": lambda value: process_indices_list(value, no_value_func=no_value_func, args_sep=args_sep)
-         },
-         const.PosTags: {
-             "writer": lambda value: list_to_str(value),
-             "parser": lambda value: str_to_list(value)
-         }
-     }
-
-
- def create_nn_val_writer_fmt(fmt_type, args_sep=","):
-     assert(isinstance(fmt_type, str))
-     d = create_nn_column_formatters(args_sep=args_sep)
-     for k, v in d.items():
-         d[k] = v[fmt_type]
-     return d
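The removed factory above pairs, for every neural-network column, a writer that joins an index list into a separator-delimited string with a parser that restores the list via process_indices_list. A minimal, self-contained sketch of that round trip, assuming process_indices_list simply splits on the separator and casts each item to int (the real helper lives in arekit.common.data.rows_fmt):

    # Hypothetical stand-ins for the removed writer/parser pair; the assumption
    # here is that parsing is a plain split on the separator followed by int().
    args_sep = ","

    def list_to_str(inds_iter):
        return args_sep.join(str(i) for i in inds_iter)

    def str_to_list(value):
        return [int(i) for i in value.split(args_sep)] if value else []

    assert str_to_list(list_to_str([3, 1, 4])) == [3, 1, 4]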
arekit/contrib/networks/input/term_types.py
@@ -1,13 +0,0 @@
- class TermTypes(object):
-     """ Types of input terms that may occur within the
-         input sequence of the neural network moodels.
-     """
-
-     WORD = "word"
-     ENTITY = "entity"
-     FRAME = "frame"
-     TOKEN = "token"
-
-     @staticmethod
-     def iter_types():
-         return [TermTypes.WORD, TermTypes.ENTITY, TermTypes.FRAME, TermTypes.TOKEN]
arekit/contrib/networks/input/terms_mapping.py
@@ -1,60 +0,0 @@
- from arekit.common.data.input.terms_mapper import OpinionContainingTextTermsMapper
- from arekit.common.entities.base import Entity
- from arekit.common.frames.text_variant import TextFrameVariant
- from arekit.contrib.networks.input.term_types import TermTypes
-
-
- class VectorizedNetworkTermMapping(OpinionContainingTextTermsMapper):
-     """ For every element returns: (word, embedded vector)
-     """
-
-     def __init__(self, string_entities_formatter, vectorizers):
-         """string_emb_entity_formatter:
-                 Utilized in order to obtain embedding value from predefined_embeding for entities
-            vectorizers:
-                 dict
-         """
-         assert(isinstance(vectorizers, dict))
-
-         for term_type in TermTypes.iter_types():
-             assert(term_type in vectorizers)
-
-         super(VectorizedNetworkTermMapping, self).__init__(
-             entity_formatter=string_entities_formatter)
-
-         self.__vectorizers = vectorizers
-
-     def map_term(self, term_type, term):
-         """Universal term mapping method.
-
-         Args:
-             term_type (TermTypes): The type of term to map.
-             term (str): The term to map.
-
-         Returns:
-             The mapped term.
-         """
-         return self.__vectorizers[term_type].create_term_embedding(term=term)
-
-     def map_word(self, w_ind, word):
-         return self.map_term(TermTypes.WORD, word)
-
-     def map_text_frame_variant(self, fv_ind, text_frame_variant):
-         assert(isinstance(text_frame_variant, TextFrameVariant))
-         return self.map_term(TermTypes.FRAME, text_frame_variant.Variant.get_value())
-
-     def map_token(self, t_ind, token):
-         """ It assumes to be composed for all the supported types.
-         """
-         return self.map_term(TermTypes.TOKEN, token.get_token_value())
-
-     def map_entity(self, e_ind, entity):
-         assert(isinstance(entity, Entity))
-
-         # Value extraction
-         str_formatted_entity = super(VectorizedNetworkTermMapping, self).map_entity(
-             e_ind=e_ind,
-             entity=entity)
-
-         # Vector extraction
-         return self.map_term(TermTypes.ENTITY, str_formatted_entity)
arekit/contrib/networks/vectorizer.py
@@ -1,6 +0,0 @@
- class BaseVectorizer(object):
-     """ Custom API for vectorization
-     """
-
-     def create_term_embedding(self, term):
-         raise NotImplementedError()
arekit/contrib/utils/data/readers/__init__.py
File without changes
arekit/contrib/utils/data/readers/base.py
@@ -1,7 +0,0 @@
- class BaseReader(object):
-
-     def extension(self):
-         raise NotImplementedError()
-
-     def read(self, target):
-         raise NotImplementedError()
arekit/contrib/utils/data/readers/csv_pd.py
@@ -1,38 +0,0 @@
- import importlib
-
- from arekit.contrib.utils.data.readers.base import BaseReader
- from arekit.contrib.utils.data.storages.pandas_based import PandasBasedRowsStorage
-
-
- class PandasCsvReader(BaseReader):
-     """ Represents a CSV-based reader, implmented via pandas API.
-     """
-
-     def __init__(self, sep='\t', header='infer', compression='infer', encoding='utf-8', col_types=None,
-                  custom_extension=None):
-         self.__sep = sep
-         self.__compression = compression
-         self.__encoding = encoding
-         self.__header = header
-         self.__custom_extension = custom_extension
-
-         # Special assignation of types for certain columns.
-         self.__col_types = col_types
-         if self.__col_types is None:
-             self.__col_types = dict()
-
-     def extension(self):
-         return ".tsv.gz" if self.__custom_extension is None else self.__custom_extension
-
-     def __from_csv(self, filepath):
-         pd = importlib.import_module("pandas")
-         return pd.read_csv(filepath,
-                            sep=self.__sep,
-                            encoding=self.__encoding,
-                            compression=self.__compression,
-                            dtype=self.__col_types,
-                            header=self.__header)
-
-     def read(self, target):
-         df = self.__from_csv(filepath=target)
-         return PandasBasedRowsStorage(df)
arekit/contrib/utils/data/readers/jsonl.py
@@ -1,15 +0,0 @@
- from arekit.contrib.utils.data.readers.base import BaseReader
- from arekit.contrib.utils.data.storages.jsonl_based import JsonlBasedRowsStorage
-
-
- class JsonlReader(BaseReader):
-
-     def extension(self):
-         return ".jsonl"
-
-     def read(self, target):
-         rows = []
-         with open(target, "r") as f:
-             for line in f.readlines():
-                 rows.append(line)
-         return JsonlBasedRowsStorage(rows)
arekit/contrib/utils/data/readers/sqlite.py
@@ -1,14 +0,0 @@
- from arekit.contrib.utils.data.readers.base import BaseReader
- from arekit.contrib.utils.data.storages.sqlite_based import SQliteBasedRowsStorage
-
-
- class SQliteReader(BaseReader):
-
-     def __init__(self, table_name):
-         self.__table_name = table_name
-
-     def extension(self):
-         return ".sqlite"
-
-     def read(self, target):
-         return SQliteBasedRowsStorage(path=target, table_name=self.__table_name)
arekit/contrib/utils/data/service/__init__.py
File without changes
arekit/contrib/utils/data/service/balance.py
@@ -1,50 +0,0 @@
- import gc
- import importlib
- from arekit.contrib.utils.data.storages.pandas_based import PandasBasedRowsStorage
-
-
- class PandasBasedStorageBalancing(object):
-
-     @staticmethod
-     def create_balanced_from(storage, column_name, free_origin=True):
-         """ Performs oversampled balancing.
-
-             Note: it is quite important to remove previously created storage
-             in order to avoid memory leaking.
-
-             storage: PandasBasedRowsStorage
-                 storage contents to be balanced.
-
-             column_name: str
-                 column utilized for balancing.
-
-             free_origin: bool
-                 indicates whether there is a need to release the resources
-                 utilized for the original storage.
-         """
-         assert(isinstance(storage, PandasBasedRowsStorage))
-
-         original_df = storage.DataFrame
-
-         max_size = original_df[column_name].value_counts().max()
-
-         dframes = []
-         for class_index, group in original_df.groupby(column_name):
-             dframes.append(group.sample(max_size - len(group), replace=True))
-
-         # Clear resources.
-         pd = importlib.import_module("pandas")
-         balanced_df = pd.concat(dframes + [original_df])
-
-         # Removing temporary created dataframe.
-         for df in dframes:
-             del df
-
-         # Marking the original dataframe as released
-         # in terms of the allocated memory for it.
-         if free_origin:
-             storage.free()
-
-         gc.collect()
-
-         return PandasBasedRowsStorage(df=balanced_df)
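The helper removed above balances a dataset by oversampling every class up to the size of the largest one. The same idea in a self-contained pandas sketch, independent of the removed storage classes (the column name "label" and the toy rows are invented for illustration):

    import pandas as pd

    # Toy frame: "pos" occurs three times, "neg" once.
    df = pd.DataFrame({"label": ["pos", "pos", "pos", "neg"],
                       "text": ["a", "b", "c", "d"]})
    max_size = df["label"].value_counts().max()

    # Draw (max_size - len(group)) extra rows per class with replacement,
    # then append them to the original frame.
    extra = [group.sample(max_size - len(group), replace=True)
             for _, group in df.groupby("label")]
    balanced = pd.concat(extra + [df])

    assert balanced["label"].value_counts().nunique() == 1  # every class now has max_size rows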
arekit/contrib/utils/data/storages/pandas_based.py
@@ -1,123 +0,0 @@
- import importlib
-
- import numpy as np
-
- from arekit.common.data.input.providers.columns.base import BaseColumnsProvider
- from arekit.common.data.storages.base import BaseRowsStorage, logger
- from arekit.common.utils import progress_bar_iter
-
-
- class PandasBasedRowsStorage(BaseRowsStorage):
-     """ Storage Kernel functions implementation,
-         based on the pandas DataFrames.
-     """
-
-     def __init__(self, df=None):
-         self._df = df
-
-     @property
-     def DataFrame(self):
-         # TODO. Temporary hack, however this should be removed in future.
-         return self._df
-
-     @staticmethod
-     def __create_empty(cols_with_types):
-         """ cols_with_types: list of pairs ("name", dtype)
-         """
-         assert(isinstance(cols_with_types, list))
-         data = np.empty(0, dtype=np.dtype(cols_with_types))
-         pd = importlib.import_module("pandas")
-         return pd.DataFrame(data)
-
-     def __filter(self, column_name, value):
-         return self._df[self._df[column_name] == value]
-
-     @staticmethod
-     def __iter_rows_core(df):
-         for row_index, row in df.iterrows():
-             yield row_index, row
-
-     def __fill_with_blank_rows(self, row_id_column_name, rows_count):
-         assert(isinstance(row_id_column_name, str))
-         assert(isinstance(rows_count, int))
-         self._df[row_id_column_name] = list(range(rows_count))
-         self._df.set_index(row_id_column_name, inplace=True)
-
-     # region protected methods
-
-     def iter_column_names(self):
-         return iter(self._df.columns)
-
-     def iter_column_types(self):
-         return iter(self._df.dtypes)
-
-     def _set_row_value(self, row_ind, column, value):
-         self._df.at[row_ind, column] = value
-
-     def _iter_rows(self):
-         for row_index, row in self.__iter_rows_core(self._df):
-             yield row_index, row.to_dict()
-
-     def _get_rows_count(self):
-         return len(self._df)
-
-     # endregion
-
-     # region public methods
-
-     def fill(self, iter_rows_func, columns_provider, row_handler=None, rows_count=None, desc=""):
-         """ NOTE: We provide the rows counting which is required
-             in order to know an expected amount of rows in advace
-             due to the specifics of the pandas memory allocation
-             for the DataFrames.
-             The latter allows us avoid rows appending, which
-             may significantly affects on performance once the size
-             of DataFrame becomes relatively large.
-         """
-         assert(isinstance(columns_provider, BaseColumnsProvider))
-
-         logger.info("Rows calculation process started. [Required by Pandas-Based storage kernel]")
-         logged_rows_it = progress_bar_iter(
-             iterable=iter_rows_func(True),
-             desc="Calculating rows count ({reason})".format(reason=desc),
-             unit="rows")
-         rows_count = sum(1 for _ in logged_rows_it)
-
-         logger.info("Filling with blank rows: {}".format(rows_count))
-         self.__fill_with_blank_rows(row_id_column_name=columns_provider.ROW_ID,
-                                     rows_count=rows_count)
-         logger.info("Completed!")
-
-         super(PandasBasedRowsStorage, self).fill(iter_rows_func=iter_rows_func,
-                                                  row_handler=row_handler,
-                                                  columns_provider=columns_provider,
-                                                  rows_count=rows_count)
-
-     def get_row(self, row_index):
-         return self._df.iloc[row_index]
-
-     def get_cell(self, row_index, column_name):
-         return self._df.iloc[row_index][column_name]
-
-     def iter_column_values(self, column_name, dtype=None):
-         values = self._df[column_name]
-         if dtype is None:
-             return values
-         return values.astype(dtype)
-
-     def find_by_value(self, column_name, value):
-         return self.__filter(column_name=column_name, value=value)
-
-     def init_empty(self, columns_provider):
-         cols_with_types = columns_provider.get_columns_list_with_types()
-         self._df = self.__create_empty(cols_with_types)
-
-     def iter_shuffled(self):
-         shuffled_df = self._df.sample(frac=1)
-         return self.__iter_rows_core(shuffled_df)
-
-     def free(self):
-         del self._df
-         super(PandasBasedRowsStorage, self).free()
-
-     # endregion
arekit/contrib/utils/data/writers/csv_native.py
@@ -1,63 +0,0 @@
- import csv
- import os
- from os.path import dirname
-
- from arekit.common.data.storages.base import BaseRowsStorage
- from arekit.contrib.utils.data.storages.row_cache import RowCacheStorage
- from arekit.contrib.utils.data.writers.base import BaseWriter
-
-
- class NativeCsvWriter(BaseWriter):
-
-     def __init__(self, delimiter='\t', quotechar='"', quoting=csv.QUOTE_MINIMAL, header=True):
-         self.__target_f = None
-         self.__writer = None
-         self.__create_writer_func = lambda f: csv.writer(
-             f, delimiter=delimiter, quotechar=quotechar, quoting=quoting)
-         self.__header = header
-         self.__header_written = None
-
-     def extension(self):
-         return ".csv"
-
-     @staticmethod
-     def __iter_storage_column_names(storage):
-         """ Iter only those columns that existed in storage.
-         """
-         for col_name in storage.iter_column_names():
-             if col_name in storage.RowCache:
-                 yield col_name
-
-     def open_target(self, target):
-         os.makedirs(dirname(target), exist_ok=True)
-         self.__target_f = open(target, "w")
-         self.__writer = self.__create_writer_func(self.__target_f)
-         self.__header_written = not self.__header
-
-     def close_target(self):
-         self.__target_f.close()
-
-     def commit_line(self, storage):
-         assert(isinstance(storage, RowCacheStorage))
-         assert(self.__writer is not None)
-
-         if not self.__header_written:
-             self.__writer.writerow(list(self.__iter_storage_column_names(storage)))
-             self.__header_written = True
-
-         line_data = list(map(lambda col_name: storage.RowCache[col_name],
-                              self.__iter_storage_column_names(storage)))
-         self.__writer.writerow(line_data)
-
-     def write_all(self, storage, target):
-         """ Writes all the `storage` rows
-             into the `target` filepath, formatted as CSV.
-         """
-         assert(isinstance(storage, BaseRowsStorage))
-
-         with open(target, "w") as f:
-             writer = self.__create_writer_func(f)
-             for _, row in storage:
-                 #content = [row[col_name] for col_name in storage.iter_column_names()]
-                 content = [v for v in row]
-                 writer.writerow(content)
arekit/contrib/utils/data/writers/csv_pd.py
@@ -1,40 +0,0 @@
- import logging
-
- from arekit.common.data.input.providers.columns.base import BaseColumnsProvider
- from arekit.common.utils import create_dir_if_not_exists
- from arekit.contrib.utils.data.storages.pandas_based import PandasBasedRowsStorage
- from arekit.contrib.utils.data.writers.base import BaseWriter
-
- logger = logging.getLogger(__name__)
- logging.basicConfig(level=logging.INFO)
-
-
- class PandasCsvWriter(BaseWriter):
-
-     def __init__(self, write_header):
-         super(PandasCsvWriter, self).__init__()
-         self.__write_header = write_header
-
-     def extension(self):
-         return ".tsv.gz"
-
-     def write_all(self, storage, target):
-         assert(isinstance(storage, PandasBasedRowsStorage))
-         assert(isinstance(target, str))
-
-         create_dir_if_not_exists(target)
-
-         # Temporary hack, remove it in future.
-         df = storage.DataFrame
-
-         logger.info("Saving... {length}: {filepath}".format(length=len(storage), filepath=target))
-         df.to_csv(target,
-                   sep='\t',
-                   encoding='utf-8',
-                   columns=[c for c in df.columns if c != BaseColumnsProvider.ROW_ID],
-                   index=False,
-                   float_format="%.0f",
-                   compression='gzip',
-                   header=self.__write_header)
-
-         logger.info("Saving completed!")
arekit/contrib/utils/data/writers/json_opennre.py
@@ -1,132 +0,0 @@
- import json
- import logging
- import os
- from os.path import dirname
-
- from arekit.common.data import const
- from arekit.common.data.storages.base import BaseRowsStorage
- from arekit.contrib.utils.data.storages.row_cache import RowCacheStorage
- from arekit.contrib.utils.data.writers.base import BaseWriter
-
- logger = logging.getLogger(__name__)
-
-
- class OpenNREJsonWriter(BaseWriter):
-     """ This is a bag-based writer for the samples.
-         Project page: https://github.com/thunlp/OpenNRE
-
-         Every bag presented as follows:
-             {
-                 'text' or 'token': ...,
-                 'h': {'pos': [start, end], 'id': ... },
-                 't': {'pos': [start, end], 'id': ... }
-                 'id': "id_of_the_text_opinion"
-             }
-
-         In terms of the linked opinions (i0, i1, etc.) we consider id of the first opinion in linkage.
-         During the dataset reading stage via OpenNRE, these linkages automaticaly groups into bags.
-     """
-
-     def __init__(self, text_columns, encoding="utf-8", na_value="NA", keep_extra_columns=True,
-                  skip_extra_existed=True):
-         """ text_columns: list
-                 column names that expected to be joined into a single (token) column.
-         """
-         assert(isinstance(text_columns, list))
-         assert(isinstance(encoding, str))
-         self.__text_columns = text_columns
-         self.__encoding = encoding
-         self.__target_f = None
-         self.__keep_extra_columns = keep_extra_columns
-         self.__na_value = na_value
-         self.__skip_extra_existed = skip_extra_existed
-
-     def extension(self):
-         return ".jsonl"
-
-     @staticmethod
-     def __format_row(row, na_value, text_columns, keep_extra_columns, skip_extra_existed):
-         """ Formatting that is compatible with the OpenNRE.
-         """
-         assert(isinstance(na_value, str))
-
-         sample_id = row[const.ID]
-         s_ind = int(row[const.S_IND])
-         t_ind = int(row[const.T_IND])
-         bag_id = str(row[const.OPINION_ID])
-
-         # Gather tokens.
-         tokens = []
-         for text_col in text_columns:
-             if text_col in row:
-                 tokens.extend(row[text_col].split())
-
-         # Filtering JSON row.
-         formatted_data = {
-             "id": bag_id,
-             "id_orig": sample_id,
-             "token": tokens,
-             "h": {"pos": [s_ind, s_ind + 1], "id": str(bag_id + "s")},
-             "t": {"pos": [t_ind, t_ind + 1], "id": str(bag_id + "t")},
-             "relation": str(int(row[const.LABEL_UINT])) if const.LABEL_UINT in row else na_value
-         }
-
-         # Register extra fields (optionally).
-         if keep_extra_columns:
-             for key, value in row.items():
-                 if key not in formatted_data and key not in text_columns:
-                     formatted_data[key] = value
-                 else:
-                     if not skip_extra_existed:
-                         raise Exception(f"key `{key}` is already exist in formatted data "
-                                         f"or a part of the text columns list: {text_columns}")
-
-         return formatted_data
-
-     def open_target(self, target):
-         os.makedirs(dirname(target), exist_ok=True)
-         self.__target_f = open(target, "w")
-         pass
-
-     def close_target(self):
-         self.__target_f.close()
-
-     def commit_line(self, storage):
-         assert(isinstance(storage, RowCacheStorage))
-
-         # Collect existed columns.
-         row_data = {}
-         for col_name in storage.iter_column_names():
-             if col_name not in storage.RowCache:
-                 continue
-             row_data[col_name] = storage.RowCache[col_name]
-
-         bag = self.__format_row(row_data, text_columns=self.__text_columns,
-                                 keep_extra_columns=self.__keep_extra_columns,
-                                 na_value=self.__na_value,
-                                 skip_extra_existed=self.__skip_extra_existed)
-
-         self.__write_bag(bag=bag, json_file=self.__target_f)
-
-     @staticmethod
-     def __write_bag(bag, json_file):
-         assert(isinstance(bag, dict))
-         json.dump(bag, json_file, separators=(",", ":"), ensure_ascii=False)
-         json_file.write("\n")
-
-     def write_all(self, storage, target):
-         assert(isinstance(storage, BaseRowsStorage))
-         assert(isinstance(target, str))
-
-         logger.info("Saving... {rows}: {filepath}".format(rows=(len(storage)), filepath=target))
-
-         os.makedirs(os.path.dirname(target), exist_ok=True)
-         with open(target, "w", encoding=self.__encoding) as json_file:
-             for row_index, row in storage:
-                 self.__write_bag(bag=self.__format_row(row, text_columns=self.__text_columns,
-                                                        keep_extra_columns=self.__keep_extra_columns,
-                                                        na_value=self.__na_value,
-                                                        skip_extra_existed=self.__skip_extra_existed),
-                                  json_file=json_file)
-
-         logger.info("Saving completed!")