bulk-chain 0.24.1__py3-none-any.whl → 0.25.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bulk_chain/api.py ADDED
@@ -0,0 +1,79 @@
1
+ import os
2
+ from itertools import chain
3
+
4
+ from bulk_chain.core.llm_base import BaseLM
5
+ from bulk_chain.core.service_batch import BatchIterator, BatchService
6
+ from bulk_chain.core.service_data import DataService
7
+ from bulk_chain.core.service_dict import DictionaryService
8
+ from bulk_chain.core.service_json import JsonService
9
+ from bulk_chain.core.service_schema import SchemaService
10
+
11
+
12
+ INFER_MODES = {
13
+ "default": lambda llm, prompt, limit_prompt=None: llm.ask_core(
14
+ prompt[:limit_prompt] if limit_prompt is not None else prompt),
15
+ "batch": lambda llm, batch, limit_prompt=None: llm.ask_core(
16
+ DataService.limit_prompts(batch, limit=limit_prompt))
17
+ }
18
+
19
+
20
+ CWD = os.getcwd()
21
+
22
+
23
+ def _update_batch_content(c, batch, schema, infer_func):
24
+ assert (isinstance(batch, list))
25
+ assert (isinstance(c, str))
26
+
27
+ if c in schema.p2r:
28
+ for batch_item in batch:
29
+ batch_item[c] = DataService.get_prompt_text(prompt=batch_item[c]["prompt"], data_dict=batch_item)
30
+ if c in schema.r2p:
31
+ p_column = schema.r2p[c]
32
+ # This instruction takes a lot of time in a non-batching mode.
33
+ BatchService.handle_param_as_batch(batch=batch,
34
+ src_param=p_column,
35
+ tgt_param=c,
36
+ handle_func=lambda b: infer_func(b))
37
+
38
+
39
+ def _infer_batch(batch, schema, infer_func, cols=None):
40
+ assert (isinstance(batch, list))
41
+ assert (callable(infer_func))
42
+
43
+ if len(batch) == 0:
44
+ return batch
45
+
46
+ if cols is None:
47
+ first_item = batch[0]
48
+ cols = first_item.keys() if cols is None else cols
49
+
50
+ for c in cols:
51
+ _update_batch_content(c=c, batch=batch, schema=schema, infer_func=infer_func)
52
+
53
+ return batch
54
+
55
+
56
+ def iter_content(input_dicts_it, llm, schema, batch_size=1, return_batch=True, limit_prompt=None):
57
+ """ This method represent Python API aimed at application of `llm` towards
58
+ iterator of input_dicts via cache_target that refers to the SQLite using
59
+ the given `schema`
60
+ """
61
+ assert (isinstance(llm, BaseLM))
62
+
63
+ # Quick initialization of the schema.
64
+ if isinstance(schema, str):
65
+ schema = JsonService.read(schema)
66
+ if isinstance(schema, dict):
67
+ schema = SchemaService(json_data=schema)
68
+
69
+ prompts_it = map(
70
+ lambda data: DictionaryService.custom_update(src_dict=data, other_dict=schema.cot_args),
71
+ input_dicts_it
72
+ )
73
+
74
+ content_it = (_infer_batch(batch=batch,
75
+ infer_func=lambda batch: INFER_MODES["batch"](llm, batch, limit_prompt),
76
+ schema=schema)
77
+ for batch in BatchIterator(prompts_it, batch_size=batch_size))
78
+
79
+ yield from content_it if return_batch else chain.from_iterable(content_it)
@@ -1,13 +1,52 @@
1
+ import logging
2
+ import time
3
+
1
4
  from bulk_chain.core.utils import format_model_name
2
5
 
3
6
 
4
7
  class BaseLM(object):
5
8
 
6
- def __init__(self, name):
9
+ def __init__(self, name=None, attempts=None, delay_sec=1, enable_log=True,
10
+ support_batching=False, **kwargs):
11
+
7
12
  self.__name = name
13
+ self.__attempts = 1 if attempts is None else attempts
14
+ self.__delay_sec = delay_sec
15
+ self.__support_batching = support_batching
16
+
17
+ if enable_log:
18
+ self.__logger = logging.getLogger(__name__)
19
+ logging.basicConfig(level=logging.INFO)
20
+
21
+ def ask_core(self, batch):
22
+
23
+ for i in range(self.__attempts):
24
+ try:
25
+ if self.__support_batching:
26
+ # Launch in batch mode.
27
+ content = self.ask(batch)
28
+ else:
29
+ # Launch in non-batch mode.
30
+ assert len(batch) == 1, "The LM does not support batching," \
31
+ f" while size of the content is {len(batch)} which is not equal 1. " \
32
+ f"Please enable batch-supporting or set required inference settings."
33
+ content = batch[0]
34
+
35
+ response = self.ask(content)
36
+
37
+ # Wrapping into batch the response in the case of non-batching mode.
38
+ return response if self.__support_batching else [response]
39
+
40
+ except Exception as e:
41
+ if self.__logger is not None:
42
+ self.__logger.info("Unable to infer the result. Try {} out of {}.".format(i, self.__attempts))
43
+ self.__logger.info(e)
44
+ time.sleep(self.__delay_sec)
45
+
46
+ raise Exception("Can't infer")
8
47
 
9
- def ask(self, prompt):
48
+ def ask(self, content):
10
49
  raise NotImplemented()
11
50
 
12
51
  def name(self):
13
- return format_model_name(self.__name)
52
+ return format_model_name(self.__name)
@@ -33,14 +33,33 @@ class CmdArgsService:
33
33
  yield __release()
34
34
 
35
35
  @staticmethod
36
- def partition_list(lst, sep):
36
+ def __find_suffix_ind(lst, idx_from, end_prefix):
37
+ for i in range(idx_from, len(lst)):
38
+ if lst[i].startswith(end_prefix):
39
+ return i
40
+ return len(lst)
41
+
42
+ @staticmethod
43
+ def extract_native_args(lst, end_prefix):
44
+ return lst[:CmdArgsService.__find_suffix_ind(lst, idx_from=0, end_prefix=end_prefix)]
45
+
46
+ @staticmethod
47
+ def find_grouped_args(lst, starts_with, end_prefix):
37
48
  """Slices a list in two, cutting on index matching "sep"
38
49
  """
39
- if sep in lst:
40
- idx = lst.index(sep)
41
- return (lst[:idx], lst[idx+1:])
42
- else:
43
- return (lst[:], None)
50
+
51
+ # Checking the presence of starts_with.
52
+ # We have to return empty content in the case of absence starts_with in the lst.
53
+ if starts_with not in lst:
54
+ return []
55
+
56
+ # Assigning start index.
57
+ idx_from = lst.index(starts_with) + 1
58
+
59
+ # Assigning end index.
60
+ idx_to = CmdArgsService.__find_suffix_ind(lst, idx_from=idx_from, end_prefix=end_prefix)
61
+
62
+ return lst[idx_from:idx_to]
44
63
 
45
64
  @staticmethod
46
65
  def args_to_dict(args):
@@ -0,0 +1,51 @@
1
+ class BatchService(object):
2
+
3
+ @staticmethod
4
+ def handle_param_as_batch(batch, src_param, tgt_param, handle_func):
5
+ assert (isinstance(batch, list))
6
+ assert (isinstance(src_param, str))
7
+ assert (callable(handle_func))
8
+
9
+ _batch = [item[src_param] for item in batch]
10
+
11
+ # Do handling for the batch.
12
+ _handled_batch = handle_func(_batch)
13
+ assert (isinstance(_handled_batch, list))
14
+
15
+ # Apply changes.
16
+ for i, item in enumerate(batch):
17
+ item[tgt_param] = _handled_batch[i]
18
+
19
+
20
+ class BatchIterator:
21
+
22
+ def __init__(self, data_iter, batch_size, end_value=None):
23
+ assert(isinstance(batch_size, int) and batch_size > 0)
24
+ assert(callable(end_value) or end_value is None)
25
+ self.__data_iter = data_iter
26
+ self.__index = 0
27
+ self.__batch_size = batch_size
28
+ self.__end_value = end_value
29
+
30
+ def __iter__(self):
31
+ return self
32
+
33
+ def __next__(self):
34
+ buffer = []
35
+ while True:
36
+ try:
37
+ data = next(self.__data_iter)
38
+ except StopIteration:
39
+ break
40
+ buffer.append(data)
41
+ if len(buffer) == self.__batch_size:
42
+ break
43
+
44
+ if len(buffer) > 0:
45
+ self.__index += 1
46
+ return buffer
47
+
48
+ if self.__end_value is None:
49
+ raise StopIteration
50
+ else:
51
+ return self.__end_value()
@@ -20,3 +20,7 @@ class DataService(object):
20
20
  field_names = list(parse_fields_func(prompt))
21
21
  return DataService.compose_prompt_text(
22
22
  prompt=prompt, data_dict=data_dict, field_names=field_names)
23
+
24
+ @staticmethod
25
+ def limit_prompts(prompts_list, limit=None):
26
+ return [p[:limit] if limit is not None else p for p in prompts_list]
@@ -0,0 +1,10 @@
1
+ class DictionaryService:
2
+
3
+ @staticmethod
4
+ def custom_update(src_dict, other_dict):
5
+ for k, v in other_dict.items():
6
+ if k in src_dict:
7
+ raise Exception(f"The key `{k}` is already defined in both dicts with values: "
8
+ f"`{src_dict[k]}` (src) and `{v}` (other)")
9
+ src_dict[k] = v
10
+ return src_dict
@@ -4,23 +4,7 @@ import json
4
4
  class JsonService(object):
5
5
 
6
6
  @staticmethod
7
- def read_data(src):
7
+ def read(src):
8
8
  assert (isinstance(src, str))
9
9
  with open(src, "r") as f:
10
- return json.load(f)
11
-
12
- @staticmethod
13
- def read_lines(src, row_id_key=None):
14
- assert (isinstance(src, str))
15
- with open(src, "r") as f:
16
- for line_ind, line in enumerate(f.readlines()):
17
- content = json.loads(line)
18
- if row_id_key is not None:
19
- content[row_id_key] = line_ind
20
- yield content
21
-
22
- @staticmethod
23
- def write_lines(target, data_it):
24
- with open(target, "w") as f:
25
- for item in data_it:
26
- f.write(f"{json.dumps(item, ensure_ascii=False)}\n")
10
+ return json.load(f)
@@ -4,9 +4,6 @@ from bulk_chain.core.llm_base import BaseLM
4
4
  from bulk_chain.core.service_data import DataService
5
5
  from bulk_chain.core.utils import iter_params
6
6
 
7
- logger = logging.getLogger(__name__)
8
- logging.basicConfig(level=logging.INFO)
9
-
10
7
 
11
8
  def pad_str(text, pad):
12
9
  return text.rjust(len(text) + pad, ' ')
@@ -27,9 +24,12 @@ def nice_output(text, width, pad=4, remove_new_line=False):
27
24
 
28
25
 
29
26
  def chat_with_lm(lm, chain=None, model_name=None):
30
- assert(isinstance(lm, BaseLM))
31
- assert(isinstance(chain, list))
32
- assert(isinstance(model_name, str) or model_name is None)
27
+ assert (isinstance(lm, BaseLM))
28
+ assert (isinstance(chain, list))
29
+ assert (isinstance(model_name, str) or model_name is None)
30
+
31
+ logger = logging.getLogger(__name__)
32
+ logging.basicConfig(level=logging.INFO)
33
33
 
34
34
  do_exit = False
35
35
  model_name = model_name if model_name is not None else "agent"
@@ -74,9 +74,9 @@ def chat_with_lm(lm, chain=None, model_name=None):
74
74
  logger.info(nice_output(actual_prompt, pad=pad*2, remove_new_line=True, width=80))
75
75
 
76
76
  # Response.
77
- response = lm.ask(actual_prompt)
77
+ response_batch = lm.ask_core(batch=[actual_prompt])
78
78
  logger.info(pad_str(f"{model_name} (resp)->", pad=pad))
79
- logger.info(nice_output(response, pad=pad*2, remove_new_line=False, width=80))
79
+ logger.info(nice_output(response_batch[0], pad=pad * 2, remove_new_line=False, width=80))
80
80
 
81
81
  # Collecting the answer for the next turn.
82
- data_dict[prompt_args["out"]] = response
82
+ data_dict[prompt_args["out"]] = response_batch[0]
@@ -2,12 +2,11 @@ class SchemaService(object):
2
2
 
3
3
  def __init__(self, json_data):
4
4
  self.src = json_data
5
- self.name = self.src["name"]
6
5
  self.r2p, self.p2r, self.cot_args, self.chain = SchemaService.__init_schema(prompts=json_data["schema"])
7
6
 
8
7
  @classmethod
9
8
  def from_prompt(cls, prompt):
10
- prompt_schema = {"name": "prompt", "schema": [{"prompt": prompt, "out": "response", "in": "prompt"}]}
9
+ prompt_schema = {"schema": [{"prompt": prompt, "out": "response", "in": "prompt"}]}
11
10
  return cls(prompt_schema)
12
11
 
13
12
  @staticmethod
bulk_chain/infer.py CHANGED
@@ -1,17 +1,18 @@
1
+ from os.path import join, basename
2
+
1
3
  import argparse
2
4
  import logging
3
- import os
4
5
  import sys
5
6
 
7
+ from source_iter.service_csv import CsvService
8
+ from source_iter.service_jsonl import JsonlService
9
+ from source_iter.service_sqlite import SQLite3Service
6
10
  from tqdm import tqdm
7
11
 
8
- from os.path import join, basename
9
-
12
+ from bulk_chain.api import INFER_MODES, _infer_batch, CWD
10
13
  from bulk_chain.core.llm_base import BaseLM
11
- from bulk_chain.core.provider_sqlite import SQLiteProvider
12
14
  from bulk_chain.core.service_args import CmdArgsService
13
- from bulk_chain.core.service_csv import CsvService
14
- from bulk_chain.core.service_data import DataService
15
+ from bulk_chain.core.service_dict import DictionaryService
15
16
  from bulk_chain.core.service_json import JsonService
16
17
  from bulk_chain.core.service_llm import chat_with_lm
17
18
  from bulk_chain.core.service_schema import SchemaService
@@ -21,7 +22,16 @@ logger = logging.getLogger(__name__)
21
22
  logging.basicConfig(level=logging.INFO)
22
23
 
23
24
 
24
- CWD = os.getcwd()
25
+ WRITER_PROVIDERS = {
26
+ "sqlite": lambda filepath, table_name, data_it, infer_data_func, **kwargs: SQLite3Service.write(
27
+ data_it=data_it, target=filepath, table_name=table_name, data2col_func=infer_data_func,
28
+ skip_existed=True, **kwargs)
29
+ }
30
+
31
+
32
+ READER_PROVIDERS = {
33
+ "sqlite": lambda filepath, table_name: SQLite3Service.read(filepath, table=table_name)
34
+ }
25
35
 
26
36
 
27
37
  def init_llm(**model_kwargs):
@@ -44,59 +54,44 @@ def init_llm(**model_kwargs):
44
54
  return llm, llm_model_name
45
55
 
46
56
 
47
- def init_schema(json_filepath):
48
- return SchemaService(json_data=JsonService.read_data(json_filepath))
49
-
50
-
51
- def iter_content(input_dicts_iter, llm, schema, cache_target, cache_table, id_column_name):
52
- """ This method represent Python API aimed at application of `llm` towards
53
- iterator of input_dicts via cache_target that refers to the SQLite using
54
- the given `schema`
55
- """
57
+ def iter_content_cached(input_dicts_it, llm, schema, cache_target, limit_prompt=None, **cache_kwargs):
56
58
  assert (isinstance(llm, BaseLM))
57
- assert (isinstance(schema, SchemaService))
58
59
  assert (isinstance(cache_target, str))
59
- assert (isinstance(cache_table, str))
60
60
 
61
- infer_modes = {
62
- "default": lambda prompt: llm.ask(prompt[:args.limit_prompt] if args.limit_prompt is not None else prompt)
63
- }
61
+ # Quick initialization of the schema.
62
+ if isinstance(schema, str):
63
+ schema = JsonService.read(schema)
64
+ if isinstance(schema, dict):
65
+ schema = SchemaService(json_data=schema)
64
66
 
65
- def optional_update_data_records(c, data):
66
- assert (isinstance(c, str))
67
+ # Iterator of the queries.
68
+ prompts_it = map(
69
+ lambda data: DictionaryService.custom_update(src_dict=data, other_dict=schema.cot_args),
70
+ input_dicts_it
71
+ )
67
72
 
68
- if c in schema.p2r:
69
- data[c] = DataService.get_prompt_text(prompt=data[c]["prompt"], data_dict=data)
70
- if c in schema.r2p:
71
- p_column = schema.r2p[c]
72
- # This instruction takes a lot of time in a non-batching mode.
73
- data[c] = infer_modes["default"](data[p_column])
73
+ # Parse target.
74
+ cache_filepath, _, cache_table = parse_filepath(filepath=cache_target)
74
75
 
75
- return data[c]
76
+ # Perform caching first.
77
+ WRITER_PROVIDERS["sqlite"](
78
+ filepath=cache_filepath, table_name=cache_table,
79
+ data_it=tqdm(prompts_it, desc="Iter content"),
80
+ infer_data_func=lambda c, prompt: _infer_batch(
81
+ batch=[prompt], cols=[c],
82
+ infer_func=lambda batch: INFER_MODES["default"](llm, batch, limit_prompt),
83
+ schema=schema)[0][c],
84
+ **cache_kwargs)
76
85
 
77
- cache_providers = {
78
- "sqlite": lambda filepath, table_name, data_it: SQLiteProvider.write_auto(
79
- data_it=data_it, target=filepath,
80
- data2col_func=optional_update_data_records,
81
- table_name=handle_table_name(table_name if table_name is not None else "contents"),
82
- id_column_name=id_column_name)
83
- }
84
-
85
- # We optionally wrap into limiter.
86
- queries_it = optional_limit_iter(
87
- it_data=map(lambda data: data.update(schema.cot_args) or data, input_dicts_iter),
88
- limit=args.limit)
89
-
90
- # Provide data caching.
91
- cache_providers["sqlite"](cache_target, table_name=tgt_meta, data_it=tqdm(queries_it, desc="Iter content"))
92
-
93
- return SQLiteProvider.iter_rows(cache_target, table=cache_table)
86
+ # Then retrieve data.
87
+ return READER_PROVIDERS["sqlite"](filepath=cache_filepath, table_name=cache_table)
94
88
 
95
89
 
96
90
  if __name__ == '__main__':
97
91
 
98
92
  parser = argparse.ArgumentParser(description="Infer Instruct LLM inference based on CoT schema")
99
93
  parser.add_argument('--adapter', dest='adapter', type=str, default=None)
94
+ parser.add_argument('--attempts', dest='attempts', type=int, default=None)
100
95
  parser.add_argument('--id-col', dest='id_col', type=str, default="uid")
101
96
  parser.add_argument('--src', dest='src', type=str, default=None)
102
97
  parser.add_argument('--schema', dest='schema', type=str, default=None,
@@ -108,34 +103,52 @@ if __name__ == '__main__':
108
103
  parser.add_argument('--limit-prompt', dest="limit_prompt", type=int, default=None,
109
104
  help="Optional trimming prompt by the specified amount of characters.")
110
105
 
111
- native_args, model_args = CmdArgsService.partition_list(lst=sys.argv, sep="%%")
112
-
106
+ # Extract native arguments.
107
+ native_args = CmdArgsService.extract_native_args(sys.argv, end_prefix="%%")
113
108
  args = parser.parse_args(args=native_args[1:])
114
109
 
115
- # Initialize Large Language Model.
116
- model_args_dict = CmdArgsService.args_to_dict(model_args)
110
+ # Extract csv-related arguments.
111
+ csv_args = CmdArgsService.find_grouped_args(lst=sys.argv, starts_with="%%csv", end_prefix="%%")
112
+ csv_args_dict = CmdArgsService.args_to_dict(csv_args)
113
+
114
+ # Extract model-related arguments and Initialize Large Language Model.
115
+ model_args = CmdArgsService.find_grouped_args(lst=sys.argv, starts_with="%%m", end_prefix="%%")
116
+ model_args_dict = CmdArgsService.args_to_dict(model_args) | {"attempts": args.attempts}
117
117
  llm, llm_model_name = init_llm(**model_args_dict)
118
118
 
119
119
  # Setup schema.
120
- schema = init_schema(args.schema)
120
+ schema = SchemaService(json_data=JsonService.read(args.schema))
121
+ schema_name = schema.src.get("name", None)
121
122
  if schema is not None:
122
- logger.info(f"Using schema: {schema.name}")
123
+ logger.info(f"Using schema: {schema_name}")
123
124
 
124
125
  input_providers = {
125
126
  None: lambda _: chat_with_lm(llm, chain=schema.chain, model_name=llm_model_name),
126
- "csv": lambda filepath: CsvService.read(target=filepath, row_id_key=args.id_col,
127
+ "csv": lambda filepath: CsvService.read(src=filepath, row_id_key=args.id_col,
128
+ as_dict=True, skip_header=True,
129
+ delimiter=csv_args_dict.get("delimiter", ","),
130
+ escapechar=csv_args_dict.get("escapechar", None)),
131
+ "tsv": lambda filepath: CsvService.read(src=filepath, row_id_key=args.id_col,
127
132
  as_dict=True, skip_header=True,
128
- delimiter=model_args_dict.get("delimiter", "\t"),
129
- escapechar=model_args_dict.get("escapechar", None)),
130
- "jsonl": lambda filepath: JsonService.read_lines(src=filepath, row_id_key=args.id_col)
133
+ delimiter=csv_args_dict.get("delimiter", "\t"),
134
+ escapechar=csv_args_dict.get("escapechar", None)),
135
+ "jsonl": lambda filepath: JsonlService.read(src=filepath, row_id_key=args.id_col)
131
136
  }
132
137
 
133
138
  output_providers = {
134
- "csv": lambda filepath, data_it, header:
135
- CsvService.write_handled(target=filepath, data_it=data_it, header=header, data2col_func=lambda v: list(v)),
139
+ "csv": lambda filepath, data_it, header: CsvService.write(target=filepath,
140
+ data_it=data_it, header=header,
141
+ delimiter=csv_args_dict.get("delimiter", ","),
142
+ escapechar=csv_args_dict.get("escapechar", None),
143
+ it_type=None),
144
+ "tsv": lambda filepath, data_it, header: CsvService.write(target=filepath,
145
+ data_it=data_it, header=header,
146
+ delimiter=csv_args_dict.get("delimiter", "\t"),
147
+ escapechar=csv_args_dict.get("escapechar", None),
148
+ it_type=None),
136
149
  "jsonl": lambda filepath, data_it, header:
137
- JsonService.write_lines(target=filepath,
138
- data_it=map(lambda item: {key:item[i] for i, key in enumerate(header)}, data_it))
150
+ JsonlService.write(target=filepath,
151
+ data_it=map(lambda item: {key: item[i] for i, key in enumerate(header)}, data_it))
139
152
  }
140
153
 
141
154
  # Setup output.
@@ -150,24 +163,29 @@ if __name__ == '__main__':
150
163
  input_providers[src_ext](None)
151
164
  exit(0)
152
165
 
166
+ def default_output_file_template(ext):
167
+ # This is a default template for output files to be generated.
168
+ return "".join(["_".join([join(CWD, basename(src_filepath)), llm.name(), schema_name]), ext])
169
+
153
170
  # Setup cache target as well as the related table.
154
- cache_target = "".join(["_".join([join(CWD, basename(src_filepath)), llm.name(), schema.name]), f".sqlite"]) \
155
- if tgt_filepath is None else tgt_filepath
171
+ cache_filepath = default_output_file_template(".sqlite") if tgt_filepath is None else tgt_filepath
156
172
  cache_table = handle_table_name(tgt_meta if tgt_meta is not None else "contents")
157
173
 
158
- data_it = iter_content(input_dicts_iter=input_providers[src_ext](src_filepath),
159
- schema=schema,
160
- llm=llm,
161
- id_column_name=args.id_col,
162
- cache_target=cache_target,
163
- cache_table=cache_table)
174
+ # This is a content that we extracted via input provider.
175
+ it_data = input_providers[src_ext](src_filepath)
176
+
177
+ data_it = iter_content_cached(input_dicts_it=optional_limit_iter(it_data=it_data, limit=args.limit),
178
+ limit_prompt=args.limit_prompt,
179
+ schema=schema,
180
+ llm=llm,
181
+ id_column_name=args.id_col,
182
+ cache_target=":".join([cache_filepath, cache_table]))
164
183
 
165
184
  # Setup output target
166
185
  tgt_ext = src_ext if tgt_ext is None else tgt_ext
167
- output_target = "".join(["_".join([join(CWD, basename(src_filepath)), llm.name(), schema.name]), f".{tgt_ext}"]) \
168
- if tgt_filepath is None else tgt_filepath
186
+ output_target = default_output_file_template(f".{tgt_ext}") if tgt_filepath is None else tgt_filepath
169
187
 
170
188
  # Perform output writing process.
171
189
  output_providers[tgt_ext](filepath=output_target,
172
190
  data_it=data_it,
173
- header=SQLiteProvider.get_columns(target=cache_target, table=cache_table))
191
+ header=SQLite3Service.read_columns(target=cache_filepath, table=cache_table))
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: bulk_chain
3
- Version: 0.24.1
3
+ Version: 0.25.0
4
4
  Summary: A lightweight, no-strings-attached Chain-of-Thought framework for your LLM, ensuring reliable results for bulk input requests.
5
5
  Home-page: https://github.com/nicolay-r/bulk-chain
6
6
  Author: Nicolay Rusnachenko
@@ -15,32 +15,42 @@ Classifier: Topic :: Text Processing :: Linguistic
15
15
  Requires-Python: >=3.6
16
16
  Description-Content-Type: text/markdown
17
17
  License-File: LICENSE
18
- Requires-Dist: tqdm
19
18
 
20
- # bulk-chain 0.24.1
19
+ # bulk-chain 0.25.0
21
20
  ![](https://img.shields.io/badge/Python-3.9-brightgreen.svg)
22
21
  [![](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/nicolay-r/bulk-chain/blob/master/bulk_chain_tutorial.ipynb)
23
22
  [![twitter](https://img.shields.io/twitter/url/https/shields.io.svg?style=social)](https://x.com/nicolayr_/status/1847969224636961033)
23
+ [![PyPI downloads](https://img.shields.io/pypi/dm/bulk-chain.svg)](https://pypistats.org/packages/bulk-chain)
24
24
 
25
25
  <p align="center">
26
26
  <img src="logo.png"/>
27
27
  </p>
28
28
 
29
- A lightweight, no-strings-attached **[Chain-of-Thought](https://arxiv.org/abs/2201.11903) framework** for your LLM, ensuring reliable results for bulk input requests stored in `CSV` / `JSONL` / `sqlite`.
30
- It allows applying series of prompts formed into `schema` (See [related section](#chain-of-thought-schema))
29
+ A lightweight, no-strings-attached **framework** for your LLM that allows applying a [Chain-of-Thought](https://arxiv.org/abs/2201.11903) prompt `schema` (see the [related section](#chain-of-thought-schema)) to massive textual collections.
31
30
 
32
- ### Features
31
+ ### Main Features
33
32
  * ✅ **No-strings**: you're free from LLM dependencies and have flexible `venv` customization.
34
- * ✅ **Provides iterator over infinite amount of input contexts** served in `CSV`/`JSONL`.
35
- * ✅ **Progress caching**: withstanding exception during LLM calls by using `sqlite3` engine for caching LLM answers;
36
33
  * ✅ **Support schemas descriptions** for Chain-of-Thought concept.
34
+ * ✅ **Provides iterator over infinite amount of input contexts** served in `CSV`/`JSONL`.
35
+
36
+ ### Extra Features
37
+ * ✅ **Progress caching [for remote LLMs]**: withstands exceptions during LLM calls by using the `sqlite3` engine to cache LLM answers;
38
+
37
39
 
38
40
  # Installation
39
41
 
42
+ From PyPI:
43
+
40
44
  ```bash
41
45
  pip install bulk-chain
42
46
  ```
43
47
 
48
+ or latest version from here:
49
+
50
+ ```bash
51
+ pip install git+https://github.com/nicolay-r/bulk-chain@master
52
+ ```
53
+
44
54
  ## Chain-of-Thought Schema
45
55
 
46
56
  To declare a Chain-of-Thought (CoT) schema, this project uses the `JSON` format.
@@ -63,35 +73,37 @@ Below, is an example on how to declare your own schema:
63
73
  }
64
74
  ```
65
75
 
66
- Another templates are available [here](/ext/schema/thor_cot_schema.json).
76
+ Other templates are available [here](/ext/schema/).
67
77
 
68
78
  # Usage
69
79
 
70
- Just **three** simple steps:
80
+ Preliminary steps:
71
81
 
72
- 1. Define your [CoT Schema](#chain-of-thought-schema), or fetch it as shown below:
73
- ```bash
74
- !wget https://raw.githubusercontent.com/nicolay-r/bulk-chain/refs/heads/master/ext/schema/default.json
75
- ```
76
- 2. Fetch or write your own **model** or pick the one [preset here](/ext/):
77
- ```bash
78
- !wget https://raw.githubusercontent.com/nicolay-r/bulk-chain/refs/heads/master/ext/flan_t5.py
79
- ```
82
+ 1. Define your [schema](#chain-of-thought-schema) ([Example for Sentiment Analysis](/ext/schema/thor_cot_schema.json))
83
+ 2. Wrap or pick **LLM model** from the [list of presets](/ext/).
84
+
85
+ ## API
86
+
87
+ Please take a look at the [**related Wiki page**](https://github.com/nicolay-r/bulk-chain/wiki)
88
+
89
+ ## Shell
90
+
91
+ > **NOTE:** You have to install the `source-iter` package.
80
92
 
81
- 3. Launch inference in (chat mode):
82
93
  ```bash
83
- !python -m bulk_chain.infer \
84
- --schema "default.json" \
85
- --adapter "dynamic:flan_t5.py:FlanT5" \
86
- %% \
87
- --device "cpu" \
94
+ python3 -m bulk_chain.infer \
95
+ --src "<PATH-TO-YOUR-CSV-or-JSONL>" \
96
+ --schema "ext/schema/default.json" \
97
+ --adapter "dynamic:ext/replicate.py:Replicate" \
98
+ %%m \
99
+ --api_token "<REPLICATE-API-TOKEN>" \
88
100
  --temp 0.1
89
101
  ```
90
102
 
91
103
  # Embed your LLM
92
104
 
93
105
  All you have to do is to implement `BaseLM` class, that includes:
94
- * `__init__` -- for initialization;
106
+ * `__init__` -- for setting up *batching mode support* and (optional) *model name*;
95
107
  * `ask(prompt)` -- infer your model with the given `prompt`.
96
108
 
97
109
  See examples with models [here](/ext).
@@ -0,0 +1,18 @@
1
+ bulk_chain/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
+ bulk_chain/api.py,sha256=08i2tgFa_CCA0obC_Yr3rURI6MkuXYKgmuZaLcs4NLk,2807
3
+ bulk_chain/infer.py,sha256=oWtBf2itZeM3fD-_QAzABKUMbsl4BqvHmW21TUTr880,9110
4
+ bulk_chain/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
+ bulk_chain/core/llm_base.py,sha256=uX_uibm5y8STfMKNYL64EeF8UowfJGwCD_t-uftHoJE,1849
6
+ bulk_chain/core/service_args.py,sha256=x-QHaKLD1d6qaJkD4lNwx7640ku9-6Uyr3mooB_6kLc,1981
7
+ bulk_chain/core/service_batch.py,sha256=yQr6fbQd4ifQBGMhZMrQQeZpXtDchMKMGJi8XPG7thc,1430
8
+ bulk_chain/core/service_data.py,sha256=ZjJDtd1jrQm9hRCXMqe4CT_qF2XDbWBE1lVibP7tAWo,942
9
+ bulk_chain/core/service_dict.py,sha256=lAghLU-3V3xYGv5BTA327Qcw8UJYmgQRMFdggzlrUgo,383
10
+ bulk_chain/core/service_json.py,sha256=6o1xM_8c9QEjH9Q3qEmJylU9nahfRXhUd5sFF2dGJwo,182
11
+ bulk_chain/core/service_llm.py,sha256=1xbFW5OQY2ckKwIDZjsgNtnxKDp2wDjKKwyNS_yMU2s,2776
12
+ bulk_chain/core/service_schema.py,sha256=KIP4n0Tz2h1i7SIMGhgAhoiCgUFXOT1rzMt38yACS2U,1154
13
+ bulk_chain/core/utils.py,sha256=UV6Cefaw7yZiYblsCr-s9LsbcI83xe7eESBvha9A2Og,2784
14
+ bulk_chain-0.25.0.dist-info/LICENSE,sha256=VF9SjNpwwSSFEY_eP_8A1ocDCrbwfjI1pZexXdCkOwo,1076
15
+ bulk_chain-0.25.0.dist-info/METADATA,sha256=-Ky6ZekXHUCBByhSTgDYgMpC64ew8lGmQ7-I9dKsv6U,3874
16
+ bulk_chain-0.25.0.dist-info/WHEEL,sha256=pL8R0wFFS65tNSRnaOVrsw9EOkOqxLrlUPenUYnJKNo,91
17
+ bulk_chain-0.25.0.dist-info/top_level.txt,sha256=Hxq_wyH-GDXKBaA63UfBIiMJO2eCHJG5EOrXDphpeB4,11
18
+ bulk_chain-0.25.0.dist-info/RECORD,,
@@ -1,78 +0,0 @@
1
- import sqlite3
2
-
3
-
4
- class SQLiteProvider(object):
5
-
6
- @staticmethod
7
- def __create_table(table_name, columns, id_column_name,
8
- id_column_type, sqlite3_column_types, cur):
9
-
10
- # Provide the ID column.
11
- sqlite3_column_types = [id_column_type] + sqlite3_column_types
12
-
13
- # Compose the whole columns list.
14
- content = ", ".join([f"[{item[0]}] {item[1]}" for item in zip(columns, sqlite3_column_types)])
15
- cur.execute(f"CREATE TABLE IF NOT EXISTS {table_name}({content})")
16
- cur.execute(f"CREATE INDEX IF NOT EXISTS [{id_column_name}] ON {table_name}([{id_column_name}])")
17
-
18
- @staticmethod
19
- def write_auto(data_it, target, data2col_func, table_name, id_column_name="id",
20
- id_column_type="INTEGER"):
21
- """ NOTE: data_it is an iterator of dictionaries.
22
- This implementation automatically creates the table and
23
- """
24
- with sqlite3.connect(target) as con:
25
- cur = con.cursor()
26
-
27
- columns = None
28
- for data in data_it:
29
- assert(isinstance(data, dict))
30
-
31
- # Extracting columns from data.
32
- row_columns = list(data.keys())
33
- assert(id_column_name in row_columns)
34
-
35
- # Optionally create table.
36
- if columns is None:
37
-
38
- # Setup list of columns.
39
- columns = row_columns
40
- # Place ID column first.
41
- columns.insert(0, columns.pop(columns.index(id_column_name)))
42
-
43
- SQLiteProvider.__create_table(
44
- columns=columns, table_name=table_name, cur=cur,
45
- id_column_name=id_column_name, id_column_type=id_column_type,
46
- sqlite3_column_types=["TEXT"] * len(columns))
47
-
48
- # Check that each rows satisfies criteria of the first row.
49
- [Exception(f"{column} is expected to be in row!") for column in row_columns if column not in columns]
50
-
51
- uid = data[id_column_name]
52
- r = cur.execute(f"SELECT EXISTS(SELECT 1 FROM {table_name} WHERE [{id_column_name}]='{uid}');")
53
- ans = r.fetchone()[0]
54
- if ans == 1:
55
- continue
56
-
57
- params = ", ".join(tuple(['?'] * (len(columns))))
58
- row_columns_str = ", ".join([f"[{col}]" for col in row_columns])
59
- cur.execute(f"INSERT INTO {table_name}({row_columns_str}) VALUES ({params})",
60
- [data2col_func(c, data) for c in row_columns])
61
- con.commit()
62
-
63
- cur.close()
64
-
65
- @staticmethod
66
- def iter_rows(target, table="content"):
67
- with sqlite3.connect(target) as conn:
68
- cursor = conn.cursor()
69
- cursor.execute(f"SELECT * FROM {table}")
70
- for row in cursor:
71
- yield row
72
-
73
- @staticmethod
74
- def get_columns(target, table="content"):
75
- with sqlite3.connect(target) as conn:
76
- cursor = conn.cursor()
77
- cursor.execute(f"PRAGMA table_info({table})")
78
- return [row[1] for row in cursor.fetchall()]
@@ -1,57 +0,0 @@
1
- import csv
2
- import logging
3
-
4
- logger = logging.getLogger(__name__)
5
- logging.basicConfig(level=logging.INFO)
6
-
7
-
8
- class CsvService:
9
-
10
- @staticmethod
11
- def write(target, lines_it):
12
- f = open(target, "w")
13
- logger.info(f"Saving: {target}")
14
- w = csv.writer(f, delimiter="\t", quotechar='"', quoting=csv.QUOTE_MINIMAL)
15
- for content in lines_it:
16
- w.writerow(content)
17
-
18
- @staticmethod
19
- def write_handled(target, data_it, data2col_func, header):
20
-
21
- def __it():
22
- yield header
23
- for data in data_it:
24
- content = data2col_func(data)
25
- assert(len(content) == len(header))
26
- yield content
27
-
28
- CsvService.write(target, lines_it=__it())
29
-
30
- @staticmethod
31
- def read(target, skip_header=False, cols=None, as_dict=False, row_id_key=None, **csv_kwargs):
32
- assert (isinstance(row_id_key, str) or row_id_key is None)
33
- assert (isinstance(cols, list) or cols is None)
34
-
35
- header = None
36
- with open(target, newline='\n') as f:
37
- for row_id, row in enumerate(csv.reader(f, **csv_kwargs)):
38
- if skip_header and row_id == 0:
39
- header = ([row_id_key] if row_id_key is not None else []) + row
40
- continue
41
-
42
- # Determine the content we wish to return.
43
- if cols is None:
44
- content = row
45
- else:
46
- row_d = {header[col_ind]: value for col_ind, value in enumerate(row)}
47
- content = [row_d[col_name] for col_name in cols]
48
-
49
- content = ([row_id-1] if row_id_key is not None else []) + content
50
-
51
- # Optionally attach row_id to the content.
52
- if as_dict:
53
- assert (header is not None)
54
- assert (len(content) == len(header))
55
- yield {k: v for k, v in zip(header, content)}
56
- else:
57
- yield content
@@ -1,17 +0,0 @@
1
- bulk_chain/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
- bulk_chain/infer.py,sha256=hD9GJEp6P9PZRBSUCIxK8DaDjsX-oiq8VCe0rAD2EPs,7366
3
- bulk_chain/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
- bulk_chain/core/llm_base.py,sha256=5js2RJLpNS5t-De-xTpZCbLMgbz3F_b9tU_CtXhy02I,259
5
- bulk_chain/core/provider_sqlite.py,sha256=rNUvBt3aGa6Uv4a9RItyMgBZPnFbBdNjnt0Gw81lM3I,3171
6
- bulk_chain/core/service_args.py,sha256=Qr3rHsAB8wnajB-DbU-GjiEpRZFP4D6s1lVTpLkPPX4,1294
7
- bulk_chain/core/service_csv.py,sha256=-m8tNN9aIqRfJa4sPUX8ZUDP4W0fgnnOR3_0PapepDY,1984
8
- bulk_chain/core/service_data.py,sha256=18gQwSCTEsI7XFukq8AE5lDJX_QQRpasaH69g6EddV0,797
9
- bulk_chain/core/service_json.py,sha256=alYqTQbBjAcCh7anSTOZs1CLJbiWrLPpzLcoADstD0Q,743
10
- bulk_chain/core/service_llm.py,sha256=tYgMphJkXunhxdrThdfI4eM8qQTCZfEM1kabbReVjuQ,2726
11
- bulk_chain/core/service_schema.py,sha256=JVhOv2YP5VEtiwOq_zgCzhS2uF_BOATAgg6fmKRf2NQ,1209
12
- bulk_chain/core/utils.py,sha256=UV6Cefaw7yZiYblsCr-s9LsbcI83xe7eESBvha9A2Og,2784
13
- bulk_chain-0.24.1.dist-info/LICENSE,sha256=VF9SjNpwwSSFEY_eP_8A1ocDCrbwfjI1pZexXdCkOwo,1076
14
- bulk_chain-0.24.1.dist-info/METADATA,sha256=g5_Sr1pfa8v5lRs0sd7Ldch-uLiV_KfdDXaTHSen-R4,3649
15
- bulk_chain-0.24.1.dist-info/WHEEL,sha256=pL8R0wFFS65tNSRnaOVrsw9EOkOqxLrlUPenUYnJKNo,91
16
- bulk_chain-0.24.1.dist-info/top_level.txt,sha256=Hxq_wyH-GDXKBaA63UfBIiMJO2eCHJG5EOrXDphpeB4,11
17
- bulk_chain-0.24.1.dist-info/RECORD,,