bulk-chain 0.25.2__py3-none-any.whl → 0.25.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bulk_chain/api.py CHANGED
@@ -1,3 +1,4 @@
+import collections
 import os
 from itertools import chain
 
@@ -9,9 +10,8 @@ from bulk_chain.core.service_json import JsonService
 from bulk_chain.core.service_schema import SchemaService
 from bulk_chain.core.utils import dynamic_init, find_by_prefix
 
+
 INFER_MODES = {
-    "default": lambda llm, prompt, limit_prompt=None: llm.ask_core(
-        prompt[:limit_prompt] if limit_prompt is not None else prompt),
     "batch": lambda llm, batch, limit_prompt=None: llm.ask_core(
         DataService.limit_prompts(batch, limit=limit_prompt))
 }
@@ -20,40 +20,67 @@ INFER_MODES = {
 CWD = os.getcwd()
 
 
-def _update_batch_content(c, batch, schema, infer_func):
+def _handle_entry(entry, entry_info=None, **kwargs):
+
+    if isinstance(entry, str):
+        kwargs.get("callback_str_func", lambda *_: None)(entry, entry_info)
+        return entry
+    elif isinstance(entry, collections.abc.Iterable):
+        chunks = []
+        h = kwargs.get("callback_stream_func", lambda *_: None)
+
+        h(None, entry_info | {"action": "start"})
+
+        for chunk in map(lambda item: str(item), entry):
+            chunks.append(chunk)
+            h(chunk, entry_info)
+
+        h(None, entry_info | {"action": "end"})
+
+        return "".join(chunks)
+
+    raise Exception(f"Non supported type `{type(entry)}` for handling output from batch")
+
+
+def _update_batch_content(c, batch, schema, **kwargs):
     assert (isinstance(batch, list))
     assert (isinstance(c, str))
 
     if c in schema.p2r:
         for batch_item in batch:
-            batch_item[c] = DataService.get_prompt_text(prompt=batch_item[c]["prompt"], data_dict=batch_item)
+            batch_item[c] = DataService.get_prompt_text(
+                prompt=batch_item[c]["prompt"],
+                data_dict=batch_item,
+                handle_missed_func=kwargs["handle_missed_value_func"])
     if c in schema.r2p:
         p_column = schema.r2p[c]
         # This instruction takes a lot of time in a non-batching mode.
-        BatchService.handle_param_as_batch(batch=batch,
-                                           src_param=p_column,
-                                           tgt_param=c,
-                                           handle_func=lambda b: infer_func(b))
+        BatchService.handle_param_as_batch(
+            batch=batch,
+            src_param=p_column,
+            tgt_param=c,
+            handle_batch_func=lambda b: kwargs["handle_batch_func"](b),
+            handle_entry_func=lambda entry, info: _handle_entry(entry=entry, entry_info=info, **kwargs)
+        )
 
 
-def _infer_batch(batch, schema, infer_func, cols=None):
+def _infer_batch(batch, schema, cols=None, **kwargs):
     assert (isinstance(batch, list))
-    assert (callable(infer_func))
 
     if len(batch) == 0:
         return batch
 
     if cols is None:
         first_item = batch[0]
-        cols = first_item.keys() if cols is None else cols
+        cols = list(first_item.keys()) if cols is None else cols
 
     for c in cols:
-        _update_batch_content(c=c, batch=batch, schema=schema, infer_func=infer_func)
+        _update_batch_content(c=c, batch=batch, schema=schema, **kwargs)
 
     return batch
 
 
-def iter_content(input_dicts_it, llm, schema, batch_size=1, return_batch=True, limit_prompt=None):
+def iter_content(input_dicts_it, llm, schema, batch_size=1, return_batch=True, limit_prompt=None, **kwargs):
     """ This method represent Python API aimed at application of `llm` towards
         iterator of input_dicts via cache_target that refers to the SQLite using
         the given `schema`
@@ -72,8 +99,9 @@ def iter_content(input_dicts_it, llm, schema, batch_size=1, return_batch=True, l
     )
 
     content_it = (_infer_batch(batch=batch,
-                               infer_func=lambda batch: INFER_MODES["batch"](llm, batch, limit_prompt),
-                               schema=schema)
+                               handle_batch_func=lambda batch: INFER_MODES["batch"](llm, batch, limit_prompt),
+                               schema=schema,
+                               **kwargs)
                   for batch in BatchIterator(prompts_it, batch_size=batch_size))
 
     yield from content_it if return_batch else chain.from_iterable(content_it)
@@ -82,6 +110,7 @@ def iter_content(input_dicts_it, llm, schema, batch_size=1, return_batch=True, l
 def init_llm(adapter, **model_kwargs):
     """ This method perform dynamic initialization of LLM from third-party resource.
     """
+    assert (isinstance(adapter, str))
 
     # List of the Supported models and their API wrappers.
     models_preset = {
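Taken together, 0.25.3 replaces the single `infer_func` callable with a kwargs-based callback protocol. A minimal sketch of driving it, assuming an already initialized `llm` (a `BaseLM` subclass) and a `schema` accepted by `iter_content`; only the keyword names come from the diff above:

```python
# Sketch only: kwarg names are taken from this diff; `llm` and `schema`
# are assumed to be prepared elsewhere (e.g. via init_llm and a
# SchemaService-compatible schema).
from bulk_chain.api import iter_content

def on_stream(chunk, info):
    # _handle_entry emits chunk=None with {"action": "start"|"end"}
    # around the streamed chunks of each entry.
    if chunk is not None:
        print(chunk, end="", flush=True)

answers_it = iter_content(
    input_dicts_it=[{"id": 0, "text": "some input"}],
    llm=llm,
    schema=schema,
    batch_size=1,
    return_batch=False,
    callback_stream_func=on_stream,
    handle_missed_value_func=lambda col: input(f"value for `{col}`: "),
)
for answer in answers_it:
    print(answer)
```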
bulk_chain/core/provider_sqlite.py ADDED
@@ -0,0 +1,127 @@
+import sqlite3
+
+
+class SQLite3Service(object):
+
+    @staticmethod
+    def __create_table(table_name, columns, id_column_name,
+                       id_column_type, cur, sqlite3_column_types=None):
+
+        # Setting up default column types.
+        if sqlite3_column_types is None:
+            types_count = len(columns) if id_column_name in columns else len(columns) - 1
+            sqlite3_column_types = ["TEXT"] * types_count
+
+        # Provide the ID column.
+        sqlite3_column_types = [id_column_type] + sqlite3_column_types
+
+        # Compose the whole columns list.
+        content = ", ".join([f"[{item[0]}] {item[1]}" for item in zip(columns, sqlite3_column_types)])
+        cur.execute(f"CREATE TABLE IF NOT EXISTS {table_name}({content})")
+        cur.execute(f"CREATE INDEX IF NOT EXISTS [{id_column_name}] ON {table_name}([{id_column_name}])")
+
+    @staticmethod
+    def __it_row_lists(cursor):
+        for row in cursor:
+            yield row
+
+    @staticmethod
+    def create_table_if_not_exist(**kwargs):
+        return SQLite3Service.__create_table(**kwargs)
+
+    @staticmethod
+    def entry_exist(table_name, target, id_column_name, id_value, **connect_kwargs) -> bool:
+        with sqlite3.connect(target, **connect_kwargs) as con:
+            cursor = con.cursor()
+
+            # Check table existance.
+            query = "SELECT name FROM sqlite_master WHERE type='table' AND name=?"
+            cursor.execute(query, (table_name,))
+            if cursor.fetchone() is None:
+                return False
+
+            # Check element.
+            r = cursor.execute(f"SELECT EXISTS(SELECT 1 FROM {table_name} WHERE [{id_column_name}]='{id_value}');")
+            ans = r.fetchone()[0]
+            return ans == 1
+
+    @staticmethod
+    def write(data_it, target, table_name, columns=None, id_column_name="id", data2col_func=None,
+              id_column_type="INTEGER", sqlite3_column_types=None, it_type='dict',
+              create_table_if_not_exist=True, skip_existed=True, **connect_kwargs):
+
+        need_set_column_id = True
+        need_initialize_columns = columns is None
+
+        # Setup default columns.
+        columns = [] if columns is None else columns
+
+        with sqlite3.connect(target, **connect_kwargs) as con:
+            cur = con.cursor()
+
+            for content in data_it:
+
+                if it_type == 'dict':
+                    # Extracting columns from data.
+                    data = content
+                    uid = data[id_column_name]
+                    row_columns = list(data.keys())
+                    row_params_func = lambda: [data2col_func(c, data) if data2col_func is not None else data[c]
+                                               for c in row_columns]
+                    # Append columns if needed.
+                    if need_initialize_columns:
+                        columns = list(row_columns)
+                elif it_type is None:
+                    # Setup row columns.
+                    uid, data = content
+                    row_columns = columns
+                    row_params_func = lambda: [uid] + data
+                else:
+                    raise Exception(f"it_type {it_type} does not supported!")
+
+                if need_set_column_id:
+                    # Register ID column.
+                    if id_column_name not in columns:
+                        columns.append(id_column_name)
+                    # Place ID column first.
+                    columns.insert(0, columns.pop(columns.index(id_column_name)))
+                    need_set_column_id = False
+
+                if create_table_if_not_exist:
+                    SQLite3Service.__create_table(
+                        columns=columns, table_name=table_name, cur=cur,
+                        id_column_name=id_column_name, id_column_type=id_column_type,
+                        sqlite3_column_types=sqlite3_column_types)
+
+                # Check that each rows satisfies criteria of the first row.
+                [Exception(f"{column} is expected to be in row!") for column in row_columns if column not in columns]
+
+                if skip_existed:
+                    r = cur.execute(f"SELECT EXISTS(SELECT 1 FROM {table_name} WHERE [{id_column_name}]='{uid}');")
+                    ans = r.fetchone()[0]
+
+                    if ans == 1:
+                        continue
+
+                params = ", ".join(tuple(['?'] * (len(columns))))
+                row_columns_str = ", ".join([f"[{col}]" for col in row_columns])
+                content_list = row_params_func()
+                cur.execute(f"INSERT INTO {table_name}({row_columns_str}) VALUES ({params})", content_list)
+                con.commit()
+
+            cur.close()
+
+    @staticmethod
+    def read(src, table="content", **connect_kwargs):
+        with sqlite3.connect(src, **connect_kwargs) as conn:
+            cursor = conn.cursor()
+            cursor.execute(f"SELECT * FROM {table}")
+            for record_list in SQLite3Service.__it_row_lists(cursor):
+                yield record_list
+
+    @staticmethod
+    def read_columns(target, table="content", **connect_kwargs):
+        with sqlite3.connect(target, **connect_kwargs) as conn:
+            cursor = conn.cursor()
+            cursor.execute(f"PRAGMA table_info({table})")
+            return [row[1] for row in cursor.fetchall()]
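A round-trip sketch for the new provider, using only the signatures visible above; the file and table names are arbitrary examples:

```python
from bulk_chain.core.provider_sqlite import SQLite3Service

rows = [{"id": 1, "text": "hello"}, {"id": 2, "text": "world"}]

# Creates the table on first write; the `id` column is placed first.
SQLite3Service.write(data_it=iter(rows), target="cache.sqlite",
                     table_name="content", id_column_name="id")

print(SQLite3Service.read_columns(target="cache.sqlite", table="content"))
# expected: ['id', 'text']

for record in SQLite3Service.read(src="cache.sqlite", table="content"):
    print(record)  # tuples, e.g. (1, 'hello')

# entry_exist backs the new cache-skipping logic in infer.py (see below).
print(SQLite3Service.entry_exist(table_name="content", target="cache.sqlite",
                                 id_column_name="id", id_value=1))  # True
```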
bulk_chain/core/service_batch.py CHANGED
@@ -1,31 +1,32 @@
 class BatchService(object):
 
     @staticmethod
-    def handle_param_as_batch(batch, src_param, tgt_param, handle_func):
+    def handle_param_as_batch(batch, src_param, tgt_param, handle_batch_func, handle_entry_func):
         assert (isinstance(batch, list))
         assert (isinstance(src_param, str))
-        assert (callable(handle_func))
+        assert (callable(handle_batch_func))
 
         _batch = [item[src_param] for item in batch]
 
         # Do handling for the batch.
-        _handled_batch = handle_func(_batch)
+        _handled_batch = handle_batch_func(_batch)
         assert (isinstance(_handled_batch, list))
 
         # Apply changes.
         for i, item in enumerate(batch):
-            item[tgt_param] = _handled_batch[i]
+            item[tgt_param] = handle_entry_func(entry=_handled_batch[i], info={"ind": i, "param": tgt_param})
 
 
 class BatchIterator:
 
-    def __init__(self, data_iter, batch_size, end_value=None):
+    def __init__(self, data_iter, batch_size, end_value=None, filter_func=None):
         assert(isinstance(batch_size, int) and batch_size > 0)
         assert(callable(end_value) or end_value is None)
         self.__data_iter = data_iter
         self.__index = 0
         self.__batch_size = batch_size
         self.__end_value = end_value
+        self.__filter_func = (lambda _: True) if filter_func is None else filter_func
 
     def __iter__(self):
         return self
@@ -37,7 +38,8 @@ class BatchIterator:
                 data = next(self.__data_iter)
             except StopIteration:
                 break
-            buffer.append(data)
+            if self.__filter_func(data):
+                buffer.append(data)
             if len(buffer) == self.__batch_size:
                 break
 
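The new `filter_func` drops items before they are buffered, which is what lets `infer.py` below skip already-cached rows. A small self-contained sketch of the expected behavior:

```python
from bulk_chain.core.service_batch import BatchIterator

# Keep only even numbers, grouped into batches of up to 3.
batches = BatchIterator(data_iter=iter(range(10)),
                        batch_size=3,
                        filter_func=lambda x: x % 2 == 0)
print(list(batches))  # expected: [[0, 2, 4], [6, 8]]
```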
bulk_chain/core/service_data.py CHANGED
@@ -4,8 +4,8 @@ from bulk_chain.core.utils import iter_params
 class DataService(object):
 
     @staticmethod
-    def compose_prompt_text(prompt, data_dict, field_names):
-        assert(isinstance(data_dict, dict))
+    def __compose_prompt_text(prompt, data_dict, field_names):
+        assert (isinstance(data_dict, dict))
         fmt_d = {col_name: data_dict[col_name] for col_name in field_names}
 
         # Guarantee that items has correct type.
@@ -16,10 +16,14 @@ class DataService(object):
         return prompt.format(**fmt_d)
 
     @staticmethod
-    def get_prompt_text(prompt, data_dict, parse_fields_func=iter_params):
+    def get_prompt_text(prompt, data_dict, parse_fields_func=iter_params, handle_missed_func=None):
         field_names = list(parse_fields_func(prompt))
-        return DataService.compose_prompt_text(
-            prompt=prompt, data_dict=data_dict, field_names=field_names)
+
+        for col_name in field_names:
+            if col_name not in data_dict:
+                data_dict[col_name] = handle_missed_func(col_name)
+
+        return DataService.__compose_prompt_text(prompt=prompt, data_dict=data_dict, field_names=field_names)
 
     @staticmethod
     def limit_prompts(prompts_list, limit=None):
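A sketch of the new `handle_missed_func` hook, assuming `iter_params` extracts `{field}` placeholders (as its use elsewhere in the package suggests):

```python
from bulk_chain.core.service_data import DataService

text = DataService.get_prompt_text(
    prompt="Summarize {title} by {author}:",
    data_dict={"title": "A Study in Scarlet"},
    handle_missed_func=lambda col: f"<missing:{col}>")
print(text)  # expected: Summarize A Study in Scarlet by <missing:author>:
```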
bulk_chain/core/service_llm.py CHANGED
@@ -1,6 +1,5 @@
+from bulk_chain.api import iter_content
 from bulk_chain.core.llm_base import BaseLM
-from bulk_chain.core.service_data import DataService
-from bulk_chain.core.utils import iter_params
 from bulk_chain.core.utils_logger import StreamedLogger
 
 
@@ -13,82 +12,57 @@ def nice_output(text, remove_new_line=False):
     return short_text
 
 
-def chat_with_lm(lm, preset_dict=None, chain=None, model_name=None, pad=0):
+def chat_with_lm(lm, preset_dict=None, schema=None, model_name=None, pad=0):
     assert (isinstance(lm, BaseLM))
-    assert (isinstance(chain, list))
     assert (isinstance(model_name, str) or model_name is None)
 
     preset_dict = {} if preset_dict is None else preset_dict
 
     streamed_logger = StreamedLogger(__name__)
-
     do_exit = False
     model_name = model_name if model_name is not None else "agent"
 
     while not do_exit:
 
-        streamed_logger.info("----------------")
-        streamed_logger.info("\n")
-
         # Launching the CoT engine loop.
         data_dict = {} | preset_dict
-        for chain_ind, prompt_args in enumerate(chain):
-
-            # Processing the prompt.
-            prompt = prompt_args["prompt"]
-
-            # Filling necessary parameters.
-            user_informed = False
-            field_names = list(iter_params(prompt))
-            for ind, f_name in enumerate(field_names):
-
-                if f_name in data_dict:
-                    continue
-
-                user_input = input(f"Enter your prompt for `{f_name}` ({ind+1}/{len(field_names)}) "
-                                   f"(or 'exit' to quit): ")
-                user_informed = True
-
-                if user_input.lower() == 'exit':
-                    do_exit = True
-                    break
-
-                data_dict[f_name] = user_input
-
-            if do_exit:
-                break
-
-            # In the case of the initial interaction with the chain.
-            # we make sure that aware user for starting interaction.
-            if chain_ind == 0 and not user_informed:
-                user_input = input(f"Enter to continue (or 'exit' to quit) ...")
-                if user_input.lower() == 'exit':
-                    do_exit = True
-
-            # Finally asking LLM.
-            DataService.compose_prompt_text(prompt=prompt, data_dict=data_dict, field_names=field_names)
-            actual_prompt = DataService.get_prompt_text(prompt=prompt, data_dict=data_dict)
-
-            # Returning meta information, passed to LLM.
-            streamed_logger.info(pad_str(f"{model_name} (ask [{chain_ind+1}/{len(chain)}]) ->", pad=pad))
-            streamed_logger.info("\n")
-            streamed_logger.info(nice_output(actual_prompt, remove_new_line=True))
-            streamed_logger.info("\n\n")
-
-            # Response.
-            response = lm.ask_core(batch=[actual_prompt])[0]
-            streamed_logger.info(pad_str(f"{model_name} (resp [{chain_ind+1}/{len(chain)}])->", pad=pad))
-            streamed_logger.info("\n")
-            if isinstance(response, str):
-                streamed_logger.info(nice_output(response, remove_new_line=False))
-                buffer = [response]
-            else:
-                buffer = []
-                for chunk in response:
-                    streamed_logger.info(chunk)
-                    buffer.append(str(chunk))
 
+        def callback_str_func(entry, info):
+            streamed_logger.info(pad_str(f"{model_name} ({info['param']})->\n", pad=pad))
+            streamed_logger.info(nice_output(entry, remove_new_line=False))
             streamed_logger.info("\n\n")
 
-            # Collecting the answer for the next turn.
-            data_dict[prompt_args["out"]] = "".join(buffer)
+        def handle_missed_value(col_name):
+            user_input = input(f"Enter your prompt for `{col_name}`"
+                               f"(or 'exit' to quit): ")
+
+            if user_input.lower() == 'exit':
+                exit(0)
+
+            return user_input
+
+        def callback_stream_func(entry, info):
+            if entry is None and info["action"] == "start":
+                streamed_logger.info(pad_str(f"{model_name} ({info['param']})->\n", pad=pad))
+                return
+            if entry is None and info["action"] == "end":
+                streamed_logger.info("\n\n")
+                return
+
+            streamed_logger.info(entry)
+
+        content_it = iter_content(
+            input_dicts_it=[data_dict],
+            llm=lm,
+            schema=schema,
+            batch_size=1,
+            return_batch=True,
+            handle_missed_value_func=handle_missed_value,
+            callback_str_func=callback_str_func,
+            callback_stream_func=callback_stream_func,
+        )
+
+        for _ in content_it:
+            user_input = input(f"Enter to continue (or 'exit' to quit) ...\n")
+            if user_input.lower() == 'exit':
+                do_exit = True
bulk_chain/core/utils.py CHANGED
@@ -2,6 +2,7 @@ import importlib
 import logging
 import sys
 from collections import Counter
+from os.path import dirname, join, basename
 
 logger = logging.getLogger(__name__)
 logging.basicConfig(level=logging.INFO)
@@ -82,13 +83,24 @@ def auto_import(name, is_class=False):
 
 
 def dynamic_init(class_dir, class_filepath, class_name=None):
-    sys.path.append(class_dir)
+
+    # Registering path.
+    target = join(class_dir, dirname(class_filepath))
+    logger.info(f"Adding sys path for `{target}`")
+    sys.path.insert(1, target)
     class_path_list = class_filepath.split('/')
-    class_path_list[-1] = '.'.join(class_path_list[-1].split('.')[:-1])
+
+    # Composing proper class name.
+    class_filename = basename(class_path_list[-1])
+    if class_filename.endswith(".py"):
+        class_filename = class_filename[:-len(".py")]
+
+    # Loading library.
     class_name = class_path_list[-1].title() if class_name is None else class_name
-    class_path = ".".join(class_path_list + [class_name])
+    class_path = ".".join([class_filename, class_name])
     logger.info(f"Dynamic loading for the file and class `{class_path}`")
     cls = auto_import(class_path, is_class=False)
+
     return cls
 
bulk_chain/demo.py CHANGED
@@ -81,4 +81,4 @@ if __name__ == '__main__':
         preset_dict[key] = value
 
     # Launch Demo.
-    chat_with_lm(llm, preset_dict=preset_dict, chain=schema.chain, model_name=llm_model_name)
+    chat_with_lm(llm, preset_dict=preset_dict, schema=schema, model_name=llm_model_name)
bulk_chain/infer.py CHANGED
bulk_chain/infer.py CHANGED
@@ -1,3 +1,4 @@
+from itertools import chain
 from os.path import join, basename
 
 import argparse
@@ -6,12 +7,13 @@ import sys
 
 from source_iter.service_csv import CsvService
 from source_iter.service_jsonl import JsonlService
-from source_iter.service_sqlite import SQLite3Service
 from tqdm import tqdm
 
 from bulk_chain.api import INFER_MODES, _infer_batch, CWD, init_llm
 from bulk_chain.core.llm_base import BaseLM
+from bulk_chain.core.provider_sqlite import SQLite3Service
 from bulk_chain.core.service_args import CmdArgsService
+from bulk_chain.core.service_batch import BatchIterator
 from bulk_chain.core.service_dict import DictionaryService
 from bulk_chain.core.service_json import JsonService
 from bulk_chain.core.service_schema import SchemaService
@@ -21,9 +23,8 @@ logger = logging.getLogger(__name__)
 logging.basicConfig(level=logging.INFO)
 
 WRITER_PROVIDERS = {
-    "sqlite": lambda filepath, table_name, data_it, infer_data_func, **kwargs: SQLite3Service.write(
-        data_it=data_it, target=filepath, table_name=table_name, data2col_func=infer_data_func,
-        skip_existed=True, **kwargs)
+    "sqlite": lambda filepath, table_name, data_it, **kwargs: SQLite3Service.write(
+        data_it=data_it, target=filepath, table_name=table_name, skip_existed=True, **kwargs)
 }
 
 READER_PROVIDERS = {
@@ -31,7 +32,19 @@ READER_PROVIDERS = {
 }
 
 
-def iter_content_cached(input_dicts_it, llm, schema, cache_target, limit_prompt=None, **cache_kwargs):
+def infer_batch(batch, columns=None, **kwargs):
+    assert (len(batch) > 0)
+    # TODO. Support proper selection of columns.
+    cols = batch[0].keys() if columns is None else columns
+    return _infer_batch(batch=batch, cols=cols, **kwargs)
+
+
+def raise_(ex):
+    raise ex
+
+
+def iter_content_cached(input_dicts_it, llm, schema, cache_target, batch_size, id_column_name, limit_prompt=None,
+                        **cache_kwargs):
     assert (isinstance(llm, BaseLM))
     assert (isinstance(cache_target, str))
 
@@ -41,23 +54,40 @@ def iter_content_cached(input_dicts_it, llm, schema, cache_target, limit_prompt=
     if isinstance(schema, dict):
         schema = SchemaService(json_data=schema)
 
+    # Parse target.
+    cache_filepath, _, cache_table = parse_filepath(filepath=cache_target)
+
     # Iterator of the queries.
     prompts_it = map(
         lambda data: DictionaryService.custom_update(src_dict=data, other_dict=schema.cot_args),
         input_dicts_it
     )
 
-    # Parse target.
-    cache_filepath, _, cache_table = parse_filepath(filepath=cache_target)
+    prompts_batched_it = BatchIterator(
+        data_iter=iter(tqdm(prompts_it, desc="Iter Content")),
+        batch_size=batch_size,
+        filter_func=lambda data: not SQLite3Service.entry_exist(
+            id_column_name=id_column_name, table_name=cache_table, target=cache_filepath,
+            id_value=data[id_column_name], **cache_kwargs)
+    )
+
+    results_it = map(
+        lambda batch: infer_batch(
+            batch=batch, schema=schema,
+            handle_batch_func=lambda batch: INFER_MODES["batch"](llm, batch, limit_prompt),
+            handle_missed_value_func=lambda col_name: raise_(
+                Exception(f"Value for {col_name} is undefined. Filling undefined values is not supported")
+            )
+        ),
+        prompts_batched_it
+    )
 
     # Perform caching first.
     WRITER_PROVIDERS["sqlite"](
-        filepath=cache_filepath, table_name=cache_table,
-        data_it=tqdm(prompts_it, desc="Iter content"),
-        infer_data_func=lambda c, prompt: _infer_batch(
-            batch=[prompt], cols=[c],
-            infer_func=lambda batch: INFER_MODES["default"](llm, batch, limit_prompt),
-            schema=schema)[0][c],
+        filepath=cache_filepath,
+        table_name=cache_table,
+        data_it=chain.from_iterable(results_it),
+        id_column_name=id_column_name,
         **cache_kwargs)
 
     # Then retrieve data.
@@ -76,6 +106,7 @@ if __name__ == '__main__':
     parser.add_argument('--output', dest='output', type=str, default=None)
     parser.add_argument('--limit', dest='limit', type=int, default=None,
                         help="Limit amount of source texts for prompting.")
+    parser.add_argument('--batch-size', dest='batch_size', type=int, default=1)
    parser.add_argument('--limit-prompt', dest="limit_prompt", type=int, default=None,
                         help="Optional trimming prompt by the specified amount of characters.")
 
@@ -89,7 +120,7 @@ if __name__ == '__main__':
 
     # Extract model-related arguments and Initialize Large Language Model.
     model_args = CmdArgsService.find_grouped_args(lst=sys.argv, starts_with="%%m", end_prefix="%%")
-    model_args_dict = CmdArgsService.args_to_dict(model_args) | {"attempts": args.attempts}
+    model_args_dict = CmdArgsService.args_to_dict(model_args) | {"attempts": 1}
     llm, llm_model_name = init_llm(adapter=args.adapter, **model_args_dict)
 
@@ -148,6 +179,7 @@ if __name__ == '__main__':
         limit_prompt=args.limit_prompt,
         schema=schema,
         llm=llm,
+        batch_size=args.batch_size,
         id_column_name=args.id_col,
         cache_target=":".join([cache_filepath, cache_table]))
 
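A sketch of the reworked cached pipeline; `cache_target` keeps the `path:table` format composed in `__main__` above, while `llm` and `schema_dict` are assumed inputs prepared elsewhere:

```python
from bulk_chain.infer import iter_content_cached

rows = [{"uid": 1, "text": "first"}, {"uid": 2, "text": "second"}]

# Rows whose `uid` already exists in out.sqlite:content are filtered out
# before batching; fresh results are written to the cache, then read back.
for record in iter_content_cached(
        input_dicts_it=iter(rows),
        llm=llm,                    # an initialized BaseLM (assumed)
        schema=schema_dict,         # schema dict or SchemaService (assumed)
        cache_target="out.sqlite:content",
        batch_size=2,
        id_column_name="uid"):
    print(record)
```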
bulk_chain-0.25.2.dist-info/METADATA → bulk_chain-0.25.3.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: bulk_chain
-Version: 0.25.2
+Version: 0.25.3
 Summary: A lightweight, no-strings-attached Chain-of-Thought framework for your LLM, ensuring reliable results for bulk input requests.
 Home-page: https://github.com/nicolay-r/bulk-chain
 Author: Nicolay Rusnachenko
@@ -18,7 +18,7 @@ License-File: LICENSE
 Requires-Dist: tqdm
 Requires-Dist: source-iter ==0.24.3
 
-# bulk-chain 0.25.2
+# bulk-chain 0.25.3
 ![](https://img.shields.io/badge/Python-3.9-brightgreen.svg)
 [![](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/nicolay-r/bulk-chain/blob/master/bulk_chain_tutorial.ipynb)
 [![twitter](https://img.shields.io/twitter/url/https/shields.io.svg?style=social)](https://x.com/nicolayr_/status/1847969224636961033)
bulk_chain-0.25.3.dist-info/RECORD ADDED
@@ -0,0 +1,21 @@
+bulk_chain/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+bulk_chain/api.py,sha256=8hXJb66bEOf1izgeBmjrB9LexSMJD7GquUhIm76lfmY,4351
+bulk_chain/demo.py,sha256=20r_-ioR3fu3eqHJnCRK4aQmBKTMgjFAHzZPJcXaEz8,3186
+bulk_chain/infer.py,sha256=Qb7ZV3_DXIh1jQoEjER_H2rO4ntREj_iIwGpEpAXJCE,9157
+bulk_chain/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+bulk_chain/core/llm_base.py,sha256=fuWxfEOSRYvnoZMOcfnq1E2LIJKnrpsnxQ1z6SmY1nM,1839
+bulk_chain/core/provider_sqlite.py,sha256=sW4Yefp_zYuL8xVys8la5hG0Ng94jSqiXelPgGGB5B0,5327
+bulk_chain/core/service_args.py,sha256=lq4Veuh4QNu8mlCv8MT9S1rMxTn4FKalyp-3boYonVk,2136
+bulk_chain/core/service_batch.py,sha256=z1OND6x40QBvK2H_Wt8sRS8MadUvYTCHmbS_dCm9t7M,1678
+bulk_chain/core/service_data.py,sha256=OWWHHnr_plwxYTxLuvMrhEc1PbSx-XC3rbFzV0hy3vk,1107
+bulk_chain/core/service_dict.py,sha256=lAghLU-3V3xYGv5BTA327Qcw8UJYmgQRMFdggzlrUgo,383
+bulk_chain/core/service_json.py,sha256=6o1xM_8c9QEjH9Q3qEmJylU9nahfRXhUd5sFF2dGJwo,182
+bulk_chain/core/service_llm.py,sha256=3WYoBgaiqoRwsoKq6VUUNMasbLj5rMCjGkU3OQVxGf8,2278
+bulk_chain/core/service_schema.py,sha256=KIP4n0Tz2h1i7SIMGhgAhoiCgUFXOT1rzMt38yACS2U,1154
+bulk_chain/core/utils.py,sha256=1irk3RGLzJxeLZ-Tv0oOceOKG0ADtpEWniG2UGbYA_U,3089
+bulk_chain/core/utils_logger.py,sha256=BD-ADxaeeuHztaYjqtIY_cIzc5r2Svq9XwRtrgIEqyI,1636
+bulk_chain-0.25.3.dist-info/LICENSE,sha256=VF9SjNpwwSSFEY_eP_8A1ocDCrbwfjI1pZexXdCkOwo,1076
+bulk_chain-0.25.3.dist-info/METADATA,sha256=-yt6rNoJSGqHRe-6z-mdf51wtwpVLYvcr8Yv_1qkdxQ,6037
+bulk_chain-0.25.3.dist-info/WHEEL,sha256=pL8R0wFFS65tNSRnaOVrsw9EOkOqxLrlUPenUYnJKNo,91
+bulk_chain-0.25.3.dist-info/top_level.txt,sha256=Hxq_wyH-GDXKBaA63UfBIiMJO2eCHJG5EOrXDphpeB4,11
+bulk_chain-0.25.3.dist-info/RECORD,,
bulk_chain-0.25.2.dist-info/RECORD DELETED
@@ -1,20 +0,0 @@
-bulk_chain/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-bulk_chain/api.py,sha256=3q1t4A5wop_BRgYanFCCSQBiGu38P9ds0hTbuxNIUKQ,3590
-bulk_chain/demo.py,sha256=3mvgEu03EyDDFzXtpx2fxozLITOn9Lo7ati6H1y54S4,3191
-bulk_chain/infer.py,sha256=gq6G48XpOK56g5I_AU2kiQirQgcrZ353kfwjjRfQhSo,8069
-bulk_chain/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-bulk_chain/core/llm_base.py,sha256=fuWxfEOSRYvnoZMOcfnq1E2LIJKnrpsnxQ1z6SmY1nM,1839
-bulk_chain/core/service_args.py,sha256=lq4Veuh4QNu8mlCv8MT9S1rMxTn4FKalyp-3boYonVk,2136
-bulk_chain/core/service_batch.py,sha256=yQr6fbQd4ifQBGMhZMrQQeZpXtDchMKMGJi8XPG7thc,1430
-bulk_chain/core/service_data.py,sha256=ZjJDtd1jrQm9hRCXMqe4CT_qF2XDbWBE1lVibP7tAWo,942
-bulk_chain/core/service_dict.py,sha256=lAghLU-3V3xYGv5BTA327Qcw8UJYmgQRMFdggzlrUgo,383
-bulk_chain/core/service_json.py,sha256=6o1xM_8c9QEjH9Q3qEmJylU9nahfRXhUd5sFF2dGJwo,182
-bulk_chain/core/service_llm.py,sha256=0lFqX02-BHI9OOdC-7hZhpsb9QrhCbKE7In3jhKXq3I,3452
-bulk_chain/core/service_schema.py,sha256=KIP4n0Tz2h1i7SIMGhgAhoiCgUFXOT1rzMt38yACS2U,1154
-bulk_chain/core/utils.py,sha256=UV6Cefaw7yZiYblsCr-s9LsbcI83xe7eESBvha9A2Og,2784
-bulk_chain/core/utils_logger.py,sha256=BD-ADxaeeuHztaYjqtIY_cIzc5r2Svq9XwRtrgIEqyI,1636
-bulk_chain-0.25.2.dist-info/LICENSE,sha256=VF9SjNpwwSSFEY_eP_8A1ocDCrbwfjI1pZexXdCkOwo,1076
-bulk_chain-0.25.2.dist-info/METADATA,sha256=-N7-wOVXryBY1jkARSgWYZUAhdLYZlxkJ8qa8Vuj9no,6037
-bulk_chain-0.25.2.dist-info/WHEEL,sha256=pL8R0wFFS65tNSRnaOVrsw9EOkOqxLrlUPenUYnJKNo,91
-bulk_chain-0.25.2.dist-info/top_level.txt,sha256=Hxq_wyH-GDXKBaA63UfBIiMJO2eCHJG5EOrXDphpeB4,11
-bulk_chain-0.25.2.dist-info/RECORD,,