bulk-chain 0.25.2__tar.gz → 0.25.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. {bulk_chain-0.25.2 → bulk_chain-0.25.3}/PKG-INFO +2 -2
  2. {bulk_chain-0.25.2 → bulk_chain-0.25.3}/README.md +1 -1
  3. {bulk_chain-0.25.2 → bulk_chain-0.25.3}/bulk_chain/api.py +44 -15
  4. bulk_chain-0.25.3/bulk_chain/core/provider_sqlite.py +127 -0
  5. {bulk_chain-0.25.2 → bulk_chain-0.25.3}/bulk_chain/core/service_batch.py +8 -6
  6. {bulk_chain-0.25.2 → bulk_chain-0.25.3}/bulk_chain/core/service_data.py +9 -5
  7. bulk_chain-0.25.3/bulk_chain/core/service_llm.py +68 -0
  8. {bulk_chain-0.25.2 → bulk_chain-0.25.3}/bulk_chain/core/utils.py +15 -3
  9. {bulk_chain-0.25.2 → bulk_chain-0.25.3}/bulk_chain/demo.py +1 -1
  10. {bulk_chain-0.25.2 → bulk_chain-0.25.3}/bulk_chain/infer.py +46 -14
  11. {bulk_chain-0.25.2 → bulk_chain-0.25.3}/bulk_chain.egg-info/PKG-INFO +2 -2
  12. {bulk_chain-0.25.2 → bulk_chain-0.25.3}/bulk_chain.egg-info/SOURCES.txt +2 -0
  13. {bulk_chain-0.25.2 → bulk_chain-0.25.3}/setup.py +1 -1
  14. {bulk_chain-0.25.2 → bulk_chain-0.25.3}/test/test_api.py +6 -3
  15. bulk_chain-0.25.3/test/test_api_streaming.py +42 -0
  16. {bulk_chain-0.25.2 → bulk_chain-0.25.3}/test/test_provider_batching.py +4 -3
  17. bulk_chain-0.25.2/bulk_chain/core/service_llm.py +0 -94
  18. {bulk_chain-0.25.2 → bulk_chain-0.25.3}/LICENSE +0 -0
  19. {bulk_chain-0.25.2 → bulk_chain-0.25.3}/bulk_chain/__init__.py +0 -0
  20. {bulk_chain-0.25.2 → bulk_chain-0.25.3}/bulk_chain/core/__init__.py +0 -0
  21. {bulk_chain-0.25.2 → bulk_chain-0.25.3}/bulk_chain/core/llm_base.py +0 -0
  22. {bulk_chain-0.25.2 → bulk_chain-0.25.3}/bulk_chain/core/service_args.py +0 -0
  23. {bulk_chain-0.25.2 → bulk_chain-0.25.3}/bulk_chain/core/service_dict.py +0 -0
  24. {bulk_chain-0.25.2 → bulk_chain-0.25.3}/bulk_chain/core/service_json.py +0 -0
  25. {bulk_chain-0.25.2 → bulk_chain-0.25.3}/bulk_chain/core/service_schema.py +0 -0
  26. {bulk_chain-0.25.2 → bulk_chain-0.25.3}/bulk_chain/core/utils_logger.py +0 -0
  27. {bulk_chain-0.25.2 → bulk_chain-0.25.3}/bulk_chain.egg-info/dependency_links.txt +0 -0
  28. {bulk_chain-0.25.2 → bulk_chain-0.25.3}/bulk_chain.egg-info/requires.txt +0 -0
  29. {bulk_chain-0.25.2 → bulk_chain-0.25.3}/bulk_chain.egg-info/top_level.txt +0 -0
  30. {bulk_chain-0.25.2 → bulk_chain-0.25.3}/setup.cfg +0 -0
  31. {bulk_chain-0.25.2 → bulk_chain-0.25.3}/test/test.py +0 -0
  32. {bulk_chain-0.25.2 → bulk_chain-0.25.3}/test/test_args_seeking.py +0 -0
  33. {bulk_chain-0.25.2 → bulk_chain-0.25.3}/test/test_cmdargs.py +0 -0

{bulk_chain-0.25.2 → bulk_chain-0.25.3}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: bulk_chain
-Version: 0.25.2
+Version: 0.25.3
 Summary: A lightweight, no-strings-attached Chain-of-Thought framework for your LLM, ensuring reliable results for bulk input requests.
 Home-page: https://github.com/nicolay-r/bulk-chain
 Author: Nicolay Rusnachenko
@@ -18,7 +18,7 @@ License-File: LICENSE
 Requires-Dist: tqdm
 Requires-Dist: source-iter==0.24.3

-# bulk-chain 0.25.2
+# bulk-chain 0.25.3
 ![](https://img.shields.io/badge/Python-3.9-brightgreen.svg)
 [![](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/nicolay-r/bulk-chain/blob/master/bulk_chain_tutorial.ipynb)
 [![twitter](https://img.shields.io/twitter/url/https/shields.io.svg?style=social)](https://x.com/nicolayr_/status/1847969224636961033)

{bulk_chain-0.25.2 → bulk_chain-0.25.3}/README.md

@@ -1,4 +1,4 @@
-# bulk-chain 0.25.2
+# bulk-chain 0.25.3
 ![](https://img.shields.io/badge/Python-3.9-brightgreen.svg)
 [![](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/nicolay-r/bulk-chain/blob/master/bulk_chain_tutorial.ipynb)
 [![twitter](https://img.shields.io/twitter/url/https/shields.io.svg?style=social)](https://x.com/nicolayr_/status/1847969224636961033)

{bulk_chain-0.25.2 → bulk_chain-0.25.3}/bulk_chain/api.py

@@ -1,3 +1,4 @@
+import collections
 import os
 from itertools import chain

@@ -9,9 +10,8 @@ from bulk_chain.core.service_json import JsonService
 from bulk_chain.core.service_schema import SchemaService
 from bulk_chain.core.utils import dynamic_init, find_by_prefix

+
 INFER_MODES = {
-    "default": lambda llm, prompt, limit_prompt=None: llm.ask_core(
-        prompt[:limit_prompt] if limit_prompt is not None else prompt),
     "batch": lambda llm, batch, limit_prompt=None: llm.ask_core(
         DataService.limit_prompts(batch, limit=limit_prompt))
 }
@@ -20,40 +20,67 @@ INFER_MODES = {
 CWD = os.getcwd()


-def _update_batch_content(c, batch, schema, infer_func):
+def _handle_entry(entry, entry_info=None, **kwargs):
+
+    if isinstance(entry, str):
+        kwargs.get("callback_str_func", lambda *_: None)(entry, entry_info)
+        return entry
+    elif isinstance(entry, collections.abc.Iterable):
+        chunks = []
+        h = kwargs.get("callback_stream_func", lambda *_: None)
+
+        h(None, entry_info | {"action": "start"})
+
+        for chunk in map(lambda item: str(item), entry):
+            chunks.append(chunk)
+            h(chunk, entry_info)
+
+        h(None, entry_info | {"action": "end"})
+
+        return "".join(chunks)
+
+    raise Exception(f"Non supported type `{type(entry)}` for handling output from batch")
+
+
+def _update_batch_content(c, batch, schema, **kwargs):
     assert (isinstance(batch, list))
     assert (isinstance(c, str))

     if c in schema.p2r:
         for batch_item in batch:
-            batch_item[c] = DataService.get_prompt_text(prompt=batch_item[c]["prompt"], data_dict=batch_item)
+            batch_item[c] = DataService.get_prompt_text(
+                prompt=batch_item[c]["prompt"],
+                data_dict=batch_item,
+                handle_missed_func=kwargs["handle_missed_value_func"])
     if c in schema.r2p:
         p_column = schema.r2p[c]
         # This instruction takes a lot of time in a non-batching mode.
-        BatchService.handle_param_as_batch(batch=batch,
-                                           src_param=p_column,
-                                           tgt_param=c,
-                                           handle_func=lambda b: infer_func(b))
+        BatchService.handle_param_as_batch(
+            batch=batch,
+            src_param=p_column,
+            tgt_param=c,
+            handle_batch_func=lambda b: kwargs["handle_batch_func"](b),
+            handle_entry_func=lambda entry, info: _handle_entry(entry=entry, entry_info=info, **kwargs)
+        )


-def _infer_batch(batch, schema, infer_func, cols=None):
+def _infer_batch(batch, schema, cols=None, **kwargs):
     assert (isinstance(batch, list))
-    assert (callable(infer_func))

     if len(batch) == 0:
         return batch

     if cols is None:
         first_item = batch[0]
-        cols = first_item.keys() if cols is None else cols
+        cols = list(first_item.keys()) if cols is None else cols

     for c in cols:
-        _update_batch_content(c=c, batch=batch, schema=schema, infer_func=infer_func)
+        _update_batch_content(c=c, batch=batch, schema=schema, **kwargs)

     return batch


-def iter_content(input_dicts_it, llm, schema, batch_size=1, return_batch=True, limit_prompt=None):
+def iter_content(input_dicts_it, llm, schema, batch_size=1, return_batch=True, limit_prompt=None, **kwargs):
     """ This method represent Python API aimed at application of `llm` towards
         iterator of input_dicts via cache_target that refers to the SQLite using
         the given `schema`
@@ -72,8 +99,9 @@ def iter_content(input_dicts_it, llm, schema, batch_size=1, return_batch=True, l
     )

     content_it = (_infer_batch(batch=batch,
-                               infer_func=lambda batch: INFER_MODES["batch"](llm, batch, limit_prompt),
-                               schema=schema)
+                               handle_batch_func=lambda batch: INFER_MODES["batch"](llm, batch, limit_prompt),
+                               schema=schema,
+                               **kwargs)
                   for batch in BatchIterator(prompts_it, batch_size=batch_size))

     yield from content_it if return_batch else chain.from_iterable(content_it)
@@ -82,6 +110,7 @@ def iter_content(input_dicts_it, llm, schema, batch_size=1, return_batch=True, l
 def init_llm(adapter, **model_kwargs):
     """ This method perform dynamic initialization of LLM from third-party resource.
     """
+    assert (isinstance(adapter, str))

     # List of the Supported models and their API wrappers.
     models_preset = {
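
The new _handle_entry helper is the core of the streaming support introduced here: a plain-string result fires callback_str_func once, while an iterable result is drained chunk by chunk through callback_stream_func, bracketed by entry=None events carrying {"action": "start"} and {"action": "end"}. A minimal standalone sketch of that contract (the callback names on_str/on_stream are hypothetical; the control flow mirrors the diff above):

    import collections.abc

    def handle_entry(entry, entry_info, on_str=None, on_stream=None):
        # String results: fire the plain-string callback once, return as-is.
        if isinstance(entry, str):
            (on_str or (lambda *_: None))(entry, entry_info)
            return entry
        # Iterable results: stream chunk by chunk, bracketed by start/end
        # events that carry entry=None, as in _handle_entry above.
        if isinstance(entry, collections.abc.Iterable):
            notify = on_stream or (lambda *_: None)
            notify(None, entry_info | {"action": "start"})
            chunks = []
            for chunk in map(str, entry):
                chunks.append(chunk)
                notify(chunk, entry_info)
            notify(None, entry_info | {"action": "end"})
            return "".join(chunks)
        raise TypeError(f"Unsupported entry type: {type(entry)}")

    # A generator of chunks is consumed, reported, and re-joined.
    result = handle_entry(iter(["Hel", "lo"]), {"param": "answer", "ind": 0},
                          on_stream=lambda c, i: print(c if c is not None else i["action"]))
    assert result == "Hello"

Because streamed chunks are re-joined into a single string, downstream chain steps always see a complete answer regardless of whether the provider streams.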

bulk_chain-0.25.3/bulk_chain/core/provider_sqlite.py (new file)

@@ -0,0 +1,127 @@
+import sqlite3
+
+
+class SQLite3Service(object):
+
+    @staticmethod
+    def __create_table(table_name, columns, id_column_name,
+                       id_column_type, cur, sqlite3_column_types=None):
+
+        # Setting up default column types.
+        if sqlite3_column_types is None:
+            types_count = len(columns) if id_column_name in columns else len(columns) - 1
+            sqlite3_column_types = ["TEXT"] * types_count
+
+        # Provide the ID column.
+        sqlite3_column_types = [id_column_type] + sqlite3_column_types
+
+        # Compose the whole columns list.
+        content = ", ".join([f"[{item[0]}] {item[1]}" for item in zip(columns, sqlite3_column_types)])
+        cur.execute(f"CREATE TABLE IF NOT EXISTS {table_name}({content})")
+        cur.execute(f"CREATE INDEX IF NOT EXISTS [{id_column_name}] ON {table_name}([{id_column_name}])")
+
+    @staticmethod
+    def __it_row_lists(cursor):
+        for row in cursor:
+            yield row
+
+    @staticmethod
+    def create_table_if_not_exist(**kwargs):
+        return SQLite3Service.__create_table(**kwargs)
+
+    @staticmethod
+    def entry_exist(table_name, target, id_column_name, id_value, **connect_kwargs) -> bool:
+        with sqlite3.connect(target, **connect_kwargs) as con:
+            cursor = con.cursor()
+
+            # Check table existance.
+            query = "SELECT name FROM sqlite_master WHERE type='table' AND name=?"
+            cursor.execute(query, (table_name,))
+            if cursor.fetchone() is None:
+                return False
+
+            # Check element.
+            r = cursor.execute(f"SELECT EXISTS(SELECT 1 FROM {table_name} WHERE [{id_column_name}]='{id_value}');")
+            ans = r.fetchone()[0]
+            return ans == 1
+
+    @staticmethod
+    def write(data_it, target, table_name, columns=None, id_column_name="id", data2col_func=None,
+              id_column_type="INTEGER", sqlite3_column_types=None, it_type='dict',
+              create_table_if_not_exist=True, skip_existed=True, **connect_kwargs):
+
+        need_set_column_id = True
+        need_initialize_columns = columns is None
+
+        # Setup default columns.
+        columns = [] if columns is None else columns
+
+        with sqlite3.connect(target, **connect_kwargs) as con:
+            cur = con.cursor()
+
+            for content in data_it:
+
+                if it_type == 'dict':
+                    # Extracting columns from data.
+                    data = content
+                    uid = data[id_column_name]
+                    row_columns = list(data.keys())
+                    row_params_func = lambda: [data2col_func(c, data) if data2col_func is not None else data[c]
+                                               for c in row_columns]
+                    # Append columns if needed.
+                    if need_initialize_columns:
+                        columns = list(row_columns)
+                elif it_type is None:
+                    # Setup row columns.
+                    uid, data = content
+                    row_columns = columns
+                    row_params_func = lambda: [uid] + data
+                else:
+                    raise Exception(f"it_type {it_type} does not supported!")
+
+                if need_set_column_id:
+                    # Register ID column.
+                    if id_column_name not in columns:
+                        columns.append(id_column_name)
+                    # Place ID column first.
+                    columns.insert(0, columns.pop(columns.index(id_column_name)))
+                    need_set_column_id = False
+
+                if create_table_if_not_exist:
+                    SQLite3Service.__create_table(
+                        columns=columns, table_name=table_name, cur=cur,
+                        id_column_name=id_column_name, id_column_type=id_column_type,
+                        sqlite3_column_types=sqlite3_column_types)
+
+                # Check that each rows satisfies criteria of the first row.
+                [Exception(f"{column} is expected to be in row!") for column in row_columns if column not in columns]
+
+                if skip_existed:
+                    r = cur.execute(f"SELECT EXISTS(SELECT 1 FROM {table_name} WHERE [{id_column_name}]='{uid}');")
+                    ans = r.fetchone()[0]
+
+                    if ans == 1:
+                        continue
+
+                params = ", ".join(tuple(['?'] * (len(columns))))
+                row_columns_str = ", ".join([f"[{col}]" for col in row_columns])
+                content_list = row_params_func()
+                cur.execute(f"INSERT INTO {table_name}({row_columns_str}) VALUES ({params})", content_list)
+                con.commit()
+
+            cur.close()
+
+    @staticmethod
+    def read(src, table="content", **connect_kwargs):
+        with sqlite3.connect(src, **connect_kwargs) as conn:
+            cursor = conn.cursor()
+            cursor.execute(f"SELECT * FROM {table}")
+            for record_list in SQLite3Service.__it_row_lists(cursor):
+                yield record_list
+
+    @staticmethod
+    def read_columns(target, table="content", **connect_kwargs):
+        with sqlite3.connect(target, **connect_kwargs) as conn:
+            cursor = conn.cursor()
+            cursor.execute(f"PRAGMA table_info({table})")
+            return [row[1] for row in cursor.fetchall()]
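
Since SQLite3Service replaces the writer previously imported from source_iter, a short usage sketch may help; it assumes bulk-chain 0.25.3 is installed, and the table/column names are illustrative:

    from bulk_chain.core.provider_sqlite import SQLite3Service

    rows = [{"ind": 0, "text": "hello"}, {"ind": 1, "text": "world"}]

    # Rows are inserted once; re-running skips ids that already exist,
    # since skip_existed=True is the default.
    SQLite3Service.write(data_it=iter(rows), target="cache.sqlite",
                         table_name="content", id_column_name="ind")

    print(SQLite3Service.entry_exist(table_name="content", target="cache.sqlite",
                                     id_column_name="ind", id_value=0))  # True
    print(SQLite3Service.read_columns(target="cache.sqlite"))            # ['ind', 'text']

    for record in SQLite3Service.read(src="cache.sqlite"):
        print(record)  # (0, 'hello'), then (1, 'world')

One design note: entry_exist and write interpolate the id value and table name directly into the SQL string rather than binding them as parameters, so ids containing quotes would break the query; callers should stick to simple numeric or word-like ids.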

{bulk_chain-0.25.2 → bulk_chain-0.25.3}/bulk_chain/core/service_batch.py

@@ -1,31 +1,32 @@
 class BatchService(object):

     @staticmethod
-    def handle_param_as_batch(batch, src_param, tgt_param, handle_func):
+    def handle_param_as_batch(batch, src_param, tgt_param, handle_batch_func, handle_entry_func):
         assert (isinstance(batch, list))
         assert (isinstance(src_param, str))
-        assert (callable(handle_func))
+        assert (callable(handle_batch_func))

         _batch = [item[src_param] for item in batch]

         # Do handling for the batch.
-        _handled_batch = handle_func(_batch)
+        _handled_batch = handle_batch_func(_batch)
         assert (isinstance(_handled_batch, list))

         # Apply changes.
         for i, item in enumerate(batch):
-            item[tgt_param] = _handled_batch[i]
+            item[tgt_param] = handle_entry_func(entry=_handled_batch[i], info={"ind": i, "param": tgt_param})


 class BatchIterator:

-    def __init__(self, data_iter, batch_size, end_value=None):
+    def __init__(self, data_iter, batch_size, end_value=None, filter_func=None):
         assert(isinstance(batch_size, int) and batch_size > 0)
         assert(callable(end_value) or end_value is None)
         self.__data_iter = data_iter
         self.__index = 0
         self.__batch_size = batch_size
         self.__end_value = end_value
+        self.__filter_func = (lambda _: True) if filter_func is None else filter_func

     def __iter__(self):
         return self
@@ -37,7 +38,8 @@ class BatchIterator:
                 data = next(self.__data_iter)
             except StopIteration:
                 break
-            buffer.append(data)
+            if self.__filter_func(data):
+                buffer.append(data)
             if len(buffer) == self.__batch_size:
                 break

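
The filter_func hook is what lets infer.py drop already-cached rows before they are grouped into batches. A toy sketch, assuming bulk-chain 0.25.3 is installed (the expected output follows from the buffering logic shown in the hunk above):

    from bulk_chain.core.service_batch import BatchIterator

    # Keep only even numbers; batches are filled from the filtered stream.
    batches_it = BatchIterator(data_iter=iter(range(10)), batch_size=3,
                               filter_func=lambda x: x % 2 == 0)

    for batch in batches_it:
        print(batch)  # [0, 2, 4] then [6, 8]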

{bulk_chain-0.25.2 → bulk_chain-0.25.3}/bulk_chain/core/service_data.py

@@ -4,8 +4,8 @@ from bulk_chain.core.utils import iter_params
 class DataService(object):

     @staticmethod
-    def compose_prompt_text(prompt, data_dict, field_names):
-        assert(isinstance(data_dict, dict))
+    def __compose_prompt_text(prompt, data_dict, field_names):
+        assert (isinstance(data_dict, dict))
         fmt_d = {col_name: data_dict[col_name] for col_name in field_names}

         # Guarantee that items has correct type.
@@ -16,10 +16,14 @@ class DataService(object):
         return prompt.format(**fmt_d)

     @staticmethod
-    def get_prompt_text(prompt, data_dict, parse_fields_func=iter_params):
+    def get_prompt_text(prompt, data_dict, parse_fields_func=iter_params, handle_missed_func=None):
         field_names = list(parse_fields_func(prompt))
-        return DataService.compose_prompt_text(
-            prompt=prompt, data_dict=data_dict, field_names=field_names)
+
+        for col_name in field_names:
+            if col_name not in data_dict:
+                data_dict[col_name] = handle_missed_func(col_name)
+
+        return DataService.__compose_prompt_text(prompt=prompt, data_dict=data_dict, field_names=field_names)

     @staticmethod
     def limit_prompts(prompts_list, limit=None):
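
get_prompt_text now consults handle_missed_func for any prompt field absent from data_dict, which is the hook the demo uses to ask the user interactively. A sketch with a hypothetical fallback:

    from bulk_chain.core.service_data import DataService

    text = DataService.get_prompt_text(
        prompt="Summarize {text} in the style of {style}.",
        data_dict={"text": "the release notes"},
        handle_missed_func=lambda col_name: f"<{col_name} not provided>")

    print(text)  # Summarize the release notes in the style of <style not provided>.

Note the side effect: the fallback value is written back into data_dict, so subsequent chain steps observe it as well.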

bulk_chain-0.25.3/bulk_chain/core/service_llm.py (new file)

@@ -0,0 +1,68 @@
+from bulk_chain.api import iter_content
+from bulk_chain.core.llm_base import BaseLM
+from bulk_chain.core.utils_logger import StreamedLogger
+
+
+def pad_str(text, pad):
+    return text.rjust(len(text) + pad, ' ')
+
+
+def nice_output(text, remove_new_line=False):
+    short_text = text.replace("\n", "") if remove_new_line else text
+    return short_text
+
+
+def chat_with_lm(lm, preset_dict=None, schema=None, model_name=None, pad=0):
+    assert (isinstance(lm, BaseLM))
+    assert (isinstance(model_name, str) or model_name is None)
+
+    preset_dict = {} if preset_dict is None else preset_dict
+
+    streamed_logger = StreamedLogger(__name__)
+    do_exit = False
+    model_name = model_name if model_name is not None else "agent"
+
+    while not do_exit:
+
+        # Launching the CoT engine loop.
+        data_dict = {} | preset_dict
+
+        def callback_str_func(entry, info):
+            streamed_logger.info(pad_str(f"{model_name} ({info['param']})->\n", pad=pad))
+            streamed_logger.info(nice_output(entry, remove_new_line=False))
+            streamed_logger.info("\n\n")
+
+        def handle_missed_value(col_name):
+            user_input = input(f"Enter your prompt for `{col_name}`"
+                               f"(or 'exit' to quit): ")
+
+            if user_input.lower() == 'exit':
+                exit(0)
+
+            return user_input
+
+        def callback_stream_func(entry, info):
+            if entry is None and info["action"] == "start":
+                streamed_logger.info(pad_str(f"{model_name} ({info['param']})->\n", pad=pad))
+                return
+            if entry is None and info["action"] == "end":
+                streamed_logger.info("\n\n")
+                return
+
+            streamed_logger.info(entry)
+
+        content_it = iter_content(
+            input_dicts_it=[data_dict],
+            llm=lm,
+            schema=schema,
+            batch_size=1,
+            return_batch=True,
+            handle_missed_value_func=handle_missed_value,
+            callback_str_func=callback_str_func,
+            callback_stream_func=callback_stream_func,
+        )
+
+        for _ in content_it:
+            user_input = input(f"Enter to continue (or 'exit' to quit) ...\n")
+            if user_input.lower() == 'exit':
+                do_exit = True
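
The start/end protocol consumed by callback_stream_func above is easy to replicate outside the demo. A minimal console printer, with sys.stdout standing in for StreamedLogger:

    import sys

    def make_stream_printer(model_name="agent"):
        # Mirrors callback_stream_func: entry=None marks start/end events,
        # any other entry is a text chunk to print immediately.
        def callback_stream_func(entry, info):
            if entry is None and info.get("action") == "start":
                sys.stdout.write(f"{model_name} ({info['param']})->\n")
            elif entry is None and info.get("action") == "end":
                sys.stdout.write("\n\n")
            else:
                sys.stdout.write(entry)
            sys.stdout.flush()
        return callback_stream_func

    # e.g.: iter_content(..., callback_stream_func=make_stream_printer("llama-3"))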

{bulk_chain-0.25.2 → bulk_chain-0.25.3}/bulk_chain/core/utils.py

@@ -2,6 +2,7 @@ import importlib
 import logging
 import sys
 from collections import Counter
+from os.path import dirname, join, basename

 logger = logging.getLogger(__name__)
 logging.basicConfig(level=logging.INFO)
@@ -82,13 +83,24 @@ def auto_import(name, is_class=False):


 def dynamic_init(class_dir, class_filepath, class_name=None):
-    sys.path.append(class_dir)
+
+    # Registering path.
+    target = join(class_dir, dirname(class_filepath))
+    logger.info(f"Adding sys path for `{target}`")
+    sys.path.insert(1, target)
     class_path_list = class_filepath.split('/')
-    class_path_list[-1] = '.'.join(class_path_list[-1].split('.')[:-1])
+
+    # Composing proper class name.
+    class_filename = basename(class_path_list[-1])
+    if class_filename.endswith(".py"):
+        class_filename = class_filename[:-len(".py")]
+
+    # Loading library.
     class_name = class_path_list[-1].title() if class_name is None else class_name
-    class_path = ".".join(class_path_list + [class_name])
+    class_path = ".".join([class_filename, class_name])
     logger.info(f"Dynamic loading for the file and class `{class_path}`")
     cls = auto_import(class_path, is_class=False)
+
     return cls

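
The rewritten dynamic_init registers the provider file's own directory on sys.path and imports the module by its bare file name, so providers nested in subdirectories now resolve correctly. A usage sketch (the provider layout is hypothetical but matches the tests below):

    from bulk_chain.core.utils import dynamic_init

    # Assuming ./providers/replicate_104.py defines a class named Replicate:
    # "./providers" is put on sys.path and "replicate_104.Replicate" is imported.
    cls = dynamic_init(class_dir=".",
                       class_filepath="providers/replicate_104.py",
                       class_name="Replicate")
    llm = cls(api_token="...", model_name="deepseek-ai/deepseek-r1")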

{bulk_chain-0.25.2 → bulk_chain-0.25.3}/bulk_chain/demo.py

@@ -81,4 +81,4 @@ if __name__ == '__main__':
         preset_dict[key] = value

     # Launch Demo.
-    chat_with_lm(llm, preset_dict=preset_dict, chain=schema.chain, model_name=llm_model_name)
+    chat_with_lm(llm, preset_dict=preset_dict, schema=schema, model_name=llm_model_name)

{bulk_chain-0.25.2 → bulk_chain-0.25.3}/bulk_chain/infer.py

@@ -1,3 +1,4 @@
+from itertools import chain
 from os.path import join, basename

 import argparse
@@ -6,12 +7,13 @@ import sys

 from source_iter.service_csv import CsvService
 from source_iter.service_jsonl import JsonlService
-from source_iter.service_sqlite import SQLite3Service
 from tqdm import tqdm

 from bulk_chain.api import INFER_MODES, _infer_batch, CWD, init_llm
 from bulk_chain.core.llm_base import BaseLM
+from bulk_chain.core.provider_sqlite import SQLite3Service
 from bulk_chain.core.service_args import CmdArgsService
+from bulk_chain.core.service_batch import BatchIterator
 from bulk_chain.core.service_dict import DictionaryService
 from bulk_chain.core.service_json import JsonService
 from bulk_chain.core.service_schema import SchemaService
@@ -21,9 +23,8 @@ logger = logging.getLogger(__name__)
 logging.basicConfig(level=logging.INFO)

 WRITER_PROVIDERS = {
-    "sqlite": lambda filepath, table_name, data_it, infer_data_func, **kwargs: SQLite3Service.write(
-        data_it=data_it, target=filepath, table_name=table_name, data2col_func=infer_data_func,
-        skip_existed=True, **kwargs)
+    "sqlite": lambda filepath, table_name, data_it, **kwargs: SQLite3Service.write(
+        data_it=data_it, target=filepath, table_name=table_name, skip_existed=True, **kwargs)
 }

 READER_PROVIDERS = {
@@ -31,7 +32,19 @@ READER_PROVIDERS = {
 }


-def iter_content_cached(input_dicts_it, llm, schema, cache_target, limit_prompt=None, **cache_kwargs):
+def infer_batch(batch, columns=None, **kwargs):
+    assert (len(batch) > 0)
+    # TODO. Support proper selection of columns.
+    cols = batch[0].keys() if columns is None else columns
+    return _infer_batch(batch=batch, cols=cols, **kwargs)
+
+
+def raise_(ex):
+    raise ex
+
+
+def iter_content_cached(input_dicts_it, llm, schema, cache_target, batch_size, id_column_name, limit_prompt=None,
+                        **cache_kwargs):
     assert (isinstance(llm, BaseLM))
     assert (isinstance(cache_target, str))

@@ -41,23 +54,40 @@ def iter_content_cached(input_dicts_it, llm, schema, cache_target, limit_prompt=
     if isinstance(schema, dict):
         schema = SchemaService(json_data=schema)

+    # Parse target.
+    cache_filepath, _, cache_table = parse_filepath(filepath=cache_target)
+
     # Iterator of the queries.
     prompts_it = map(
         lambda data: DictionaryService.custom_update(src_dict=data, other_dict=schema.cot_args),
         input_dicts_it
     )

-    # Parse target.
-    cache_filepath, _, cache_table = parse_filepath(filepath=cache_target)
+    prompts_batched_it = BatchIterator(
+        data_iter=iter(tqdm(prompts_it, desc="Iter Content")),
+        batch_size=batch_size,
+        filter_func=lambda data: not SQLite3Service.entry_exist(
+            id_column_name=id_column_name, table_name=cache_table, target=cache_filepath,
+            id_value=data[id_column_name], **cache_kwargs)
+    )
+
+    results_it = map(
+        lambda batch: infer_batch(
+            batch=batch, schema=schema,
+            handle_batch_func=lambda batch: INFER_MODES["batch"](llm, batch, limit_prompt),
+            handle_missed_value_func=lambda col_name: raise_(
+                Exception(f"Value for {col_name} is undefined. Filling undefined values is not supported")
+            )
+        ),
+        prompts_batched_it
+    )

     # Perform caching first.
     WRITER_PROVIDERS["sqlite"](
-        filepath=cache_filepath, table_name=cache_table,
-        data_it=tqdm(prompts_it, desc="Iter content"),
-        infer_data_func=lambda c, prompt: _infer_batch(
-            batch=[prompt], cols=[c],
-            infer_func=lambda batch: INFER_MODES["default"](llm, batch, limit_prompt),
-            schema=schema)[0][c],
+        filepath=cache_filepath,
+        table_name=cache_table,
+        data_it=chain.from_iterable(results_it),
+        id_column_name=id_column_name,
         **cache_kwargs)

     # Then retrieve data.
@@ -76,6 +106,7 @@ if __name__ == '__main__':
     parser.add_argument('--output', dest='output', type=str, default=None)
     parser.add_argument('--limit', dest='limit', type=int, default=None,
                         help="Limit amount of source texts for prompting.")
+    parser.add_argument('--batch-size', dest='batch_size', type=int, default=1)
     parser.add_argument('--limit-prompt', dest="limit_prompt", type=int, default=None,
                         help="Optional trimming prompt by the specified amount of characters.")

@@ -89,7 +120,7 @@ if __name__ == '__main__':

     # Extract model-related arguments and Initialize Large Language Model.
     model_args = CmdArgsService.find_grouped_args(lst=sys.argv, starts_with="%%m", end_prefix="%%")
-    model_args_dict = CmdArgsService.args_to_dict(model_args) | {"attempts": args.attempts}
+    model_args_dict = CmdArgsService.args_to_dict(model_args) | {"attempts": 1}
     llm, llm_model_name = init_llm(adapter=args.adapter, **model_args_dict)

     # Setup schema.
@@ -148,6 +179,7 @@ if __name__ == '__main__':
         limit_prompt=args.limit_prompt,
         schema=schema,
         llm=llm,
+        batch_size=args.batch_size,
         id_column_name=args.id_col,
         cache_target=":".join([cache_filepath, cache_table]))

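
Putting the pieces together: the cached pipeline now batches prompts, filters out rows whose id already sits in the SQLite cache, and streams results into the cache as they arrive. A sketch of the Python-side call, assuming llm is an already-initialized BaseLM adapter:

    from bulk_chain.infer import iter_content_cached

    data_it = iter_content_cached(
        input_dicts_it=({"ind": i, "text": f"sample text {i}"} for i in range(20)),
        llm=llm,                    # any BaseLM instance, e.g. from init_llm
        schema="schema/default.json",
        batch_size=4,               # new in 0.25.3: batched inference
        id_column_name="ind",       # rows already cached under this id are skipped
        cache_target="out.sqlite:content")

    for row in data_it:
        print(row)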

{bulk_chain-0.25.2 → bulk_chain-0.25.3}/bulk_chain.egg-info/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: bulk_chain
-Version: 0.25.2
+Version: 0.25.3
 Summary: A lightweight, no-strings-attached Chain-of-Thought framework for your LLM, ensuring reliable results for bulk input requests.
 Home-page: https://github.com/nicolay-r/bulk-chain
 Author: Nicolay Rusnachenko
@@ -18,7 +18,7 @@ License-File: LICENSE
 Requires-Dist: tqdm
 Requires-Dist: source-iter==0.24.3

-# bulk-chain 0.25.2
+# bulk-chain 0.25.3
 ![](https://img.shields.io/badge/Python-3.9-brightgreen.svg)
 [![](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/nicolay-r/bulk-chain/blob/master/bulk_chain_tutorial.ipynb)
 [![twitter](https://img.shields.io/twitter/url/https/shields.io.svg?style=social)](https://x.com/nicolayr_/status/1847969224636961033)

{bulk_chain-0.25.2 → bulk_chain-0.25.3}/bulk_chain.egg-info/SOURCES.txt

@@ -12,6 +12,7 @@ bulk_chain.egg-info/requires.txt
 bulk_chain.egg-info/top_level.txt
 bulk_chain/core/__init__.py
 bulk_chain/core/llm_base.py
+bulk_chain/core/provider_sqlite.py
 bulk_chain/core/service_args.py
 bulk_chain/core/service_batch.py
 bulk_chain/core/service_data.py
@@ -23,6 +24,7 @@ bulk_chain/core/utils.py
 bulk_chain/core/utils_logger.py
 test/test.py
 test/test_api.py
+test/test_api_streaming.py
 test/test_args_seeking.py
 test/test_cmdargs.py
 test/test_provider_batching.py

{bulk_chain-0.25.2 → bulk_chain-0.25.3}/setup.py

@@ -15,7 +15,7 @@ def get_requirements(filenames):

 setup(
     name='bulk_chain',
-    version='0.25.2',
+    version='0.25.3',
     python_requires=">=3.6",
     description='A lightweight, no-strings-attached Chain-of-Thought framework for your LLM, '
                 'ensuring reliable results for bulk input requests.',

{bulk_chain-0.25.2 → bulk_chain-0.25.3}/test/test_api.py

@@ -4,13 +4,14 @@ from os.path import join
 from bulk_chain.api import iter_content, CWD
 from bulk_chain.core.utils import dynamic_init
 from bulk_chain.infer import iter_content_cached
+from utils import current_dir, API_TOKEN


 class TestAPI(unittest.TestCase):

     llm = dynamic_init(class_dir=join(CWD, ".."),
                        class_filepath="providers/replicate_104.py",
-                       class_name="Replicate")(api_token="<API-KEY>",
+                       class_name="Replicate")(api_token=API_TOKEN,
                                                model_name="deepseek-ai/deepseek-r1")

     def it_data(self, n):
@@ -20,7 +21,8 @@ class TestAPI(unittest.TestCase):
     def test_iter_cached(self):
         data_it = iter_content_cached(input_dicts_it=self.it_data(20),
                                       llm=self.llm,
-                                      schema="../schema/default.json",
+                                      batch_size=1,
+                                      schema=join(current_dir, "schema/default.json"),
                                       # Cache-related extra parameters.
                                       cache_target="out.sqlite:content",
                                       id_column_name="ind")
@@ -33,7 +35,8 @@ class TestAPI(unittest.TestCase):
                               llm=self.llm,
                               batch_size=1,
                               return_batch=True,
-                              schema="../schema/default.json")
+                              handle_missed_value_func=lambda *_: None,
+                              schema=join(current_dir, "schema/default.json"))

         for data in data_it:
             print(data)

bulk_chain-0.25.3/test/test_api_streaming.py (new file)

@@ -0,0 +1,42 @@
+import unittest
+from os.path import join
+
+from tqdm import tqdm
+
+from bulk_chain.api import CWD, iter_content
+from bulk_chain.core.utils import dynamic_init
+from bulk_chain.core.utils_logger import StreamedLogger
+from utils import API_TOKEN, iter_test_jsonl_samples
+
+
+class TestAPI_Streaming(unittest.TestCase):
+
+    llm = dynamic_init(class_dir=join(CWD, ".."),
+                       class_filepath="providers/replicate_104.py",
+                       class_name="Replicate")(api_token=API_TOKEN,
+                                               model_name="meta/meta-llama-3-70b-instruct",
+                                               stream=True)
+
+    def test_iter(self):
+
+        streamed_logger = StreamedLogger(__name__)
+
+        def callback(chunk, info):
+            if chunk is None and info["action"] == "start":
+                streamed_logger.info(f"\n{info['param']} (batch_ind={info['ind']}):\n")
+                return
+            if chunk is None and info["action"] == "end":
+                streamed_logger.info("\n\n")
+                return
+            streamed_logger.info(chunk)
+
+        input_dicts_it = iter_test_jsonl_samples()
+        data_it = iter_content(input_dicts_it=input_dicts_it,
+                               llm=self.llm,
+                               return_batch=False,
+                               callback_stream_func=callback,
+                               handle_missed_value_func=lambda *_: None,
+                               schema="schema/thor_cot_schema.json")
+
+        for _ in tqdm(data_it):
+            streamed_logger.info("\n|NEXT ENTRY|\n")

{bulk_chain-0.25.2 → bulk_chain-0.25.3}/test/test_provider_batching.py

@@ -10,7 +10,7 @@ from utils import iter_test_jsonl_samples

 class TestProviderBatching(unittest.TestCase):

-    llm = dynamic_init(class_dir=join(CWD, ".."),
+    llm = dynamic_init(class_dir=join(CWD),
                        class_filepath="providers/transformers_flan_t5.py",
                        class_name="FlanT5")(model_name="nicolay-r/flan-t5-tsa-thor-base",
                                             max_new_tokens=128)
@@ -19,9 +19,10 @@ class TestProviderBatching(unittest.TestCase):
         input_dicts_it = iter_test_jsonl_samples()
         data_it = iter_content(input_dicts_it=input_dicts_it,
                                llm=self.llm,
-                               batch_size=20,
+                               batch_size=10,
                                return_batch=False,
-                               schema="schema/default.json")
+                               handle_missed_value_func=lambda *_: None,
+                               schema="schema/thor_cot_schema.json")

         for item in tqdm(data_it):
             print(item)

bulk_chain-0.25.2/bulk_chain/core/service_llm.py (removed)

@@ -1,94 +0,0 @@
-from bulk_chain.core.llm_base import BaseLM
-from bulk_chain.core.service_data import DataService
-from bulk_chain.core.utils import iter_params
-from bulk_chain.core.utils_logger import StreamedLogger
-
-
-def pad_str(text, pad):
-    return text.rjust(len(text) + pad, ' ')
-
-
-def nice_output(text, remove_new_line=False):
-    short_text = text.replace("\n", "") if remove_new_line else text
-    return short_text
-
-
-def chat_with_lm(lm, preset_dict=None, chain=None, model_name=None, pad=0):
-    assert (isinstance(lm, BaseLM))
-    assert (isinstance(chain, list))
-    assert (isinstance(model_name, str) or model_name is None)
-
-    preset_dict = {} if preset_dict is None else preset_dict
-
-    streamed_logger = StreamedLogger(__name__)
-
-    do_exit = False
-    model_name = model_name if model_name is not None else "agent"
-
-    while not do_exit:
-
-        streamed_logger.info("----------------")
-        streamed_logger.info("\n")
-
-        # Launching the CoT engine loop.
-        data_dict = {} | preset_dict
-        for chain_ind, prompt_args in enumerate(chain):
-
-            # Processing the prompt.
-            prompt = prompt_args["prompt"]
-
-            # Filling necessary parameters.
-            user_informed = False
-            field_names = list(iter_params(prompt))
-            for ind, f_name in enumerate(field_names):
-
-                if f_name in data_dict:
-                    continue
-
-                user_input = input(f"Enter your prompt for `{f_name}` ({ind+1}/{len(field_names)}) "
-                                   f"(or 'exit' to quit): ")
-                user_informed = True
-
-                if user_input.lower() == 'exit':
-                    do_exit = True
-                    break
-
-                data_dict[f_name] = user_input
-
-            if do_exit:
-                break
-
-            # In the case of the initial interaction with the chain.
-            # we make sure that aware user for starting interaction.
-            if chain_ind == 0 and not user_informed:
-                user_input = input(f"Enter to continue (or 'exit' to quit) ...")
-                if user_input.lower() == 'exit':
-                    do_exit = True
-
-            # Finally asking LLM.
-            DataService.compose_prompt_text(prompt=prompt, data_dict=data_dict, field_names=field_names)
-            actual_prompt = DataService.get_prompt_text(prompt=prompt, data_dict=data_dict)
-
-            # Returning meta information, passed to LLM.
-            streamed_logger.info(pad_str(f"{model_name} (ask [{chain_ind+1}/{len(chain)}]) ->", pad=pad))
-            streamed_logger.info("\n")
-            streamed_logger.info(nice_output(actual_prompt, remove_new_line=True))
-            streamed_logger.info("\n\n")
-
-            # Response.
-            response = lm.ask_core(batch=[actual_prompt])[0]
-            streamed_logger.info(pad_str(f"{model_name} (resp [{chain_ind+1}/{len(chain)}])->", pad=pad))
-            streamed_logger.info("\n")
-            if isinstance(response, str):
-                streamed_logger.info(nice_output(response, remove_new_line=False))
-                buffer = [response]
-            else:
-                buffer = []
-                for chunk in response:
-                    streamed_logger.info(chunk)
-                    buffer.append(str(chunk))
-
-            streamed_logger.info("\n\n")
-
-            # Collecting the answer for the next turn.
-            data_dict[prompt_args["out"]] = "".join(buffer)