bulk-chain 0.24.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bulk_chain/__init__.py +0 -0
- bulk_chain/core/__init__.py +0 -0
- bulk_chain/core/llm_base.py +13 -0
- bulk_chain/core/provider_sqlite.py +79 -0
- bulk_chain/core/service_args.py +48 -0
- bulk_chain/core/service_csv.py +57 -0
- bulk_chain/core/service_data.py +22 -0
- bulk_chain/core/service_json.py +26 -0
- bulk_chain/core/service_llm.py +82 -0
- bulk_chain/core/service_schema.py +34 -0
- bulk_chain/core/utils.py +101 -0
- bulk_chain/infer.py +170 -0
- bulk_chain-0.24.0.dist-info/LICENSE +21 -0
- bulk_chain-0.24.0.dist-info/METADATA +92 -0
- bulk_chain-0.24.0.dist-info/RECORD +17 -0
- bulk_chain-0.24.0.dist-info/WHEEL +5 -0
- bulk_chain-0.24.0.dist-info/top_level.txt +1 -0
bulk_chain/__init__.py
ADDED
File without changes

bulk_chain/core/__init__.py
ADDED
File without changes
bulk_chain/core/provider_sqlite.py
ADDED
@@ -0,0 +1,79 @@
+import sqlite3
+
+
+class SQLiteProvider(object):
+
+    @staticmethod
+    def __create_table(table_name, columns, id_column_name,
+                       id_column_type, sqlite3_column_types, cur):
+
+        # Provide the ID column.
+        sqlite3_column_types = [id_column_type] + sqlite3_column_types
+
+        # Compose the whole columns list.
+        content = ", ".join([" ".join(item) for item in zip(columns, sqlite3_column_types)])
+        cur.execute(f"CREATE TABLE IF NOT EXISTS {table_name}({content})")
+        cur.execute(f"CREATE INDEX IF NOT EXISTS i_id ON {table_name}({id_column_name})")
+
+    @staticmethod
+    def write_auto(data_it, target, data2col_func, table_name, id_column_name="id",
+                   id_column_type="INTEGER"):
+        """ NOTE: data_it is an iterator of dictionaries.
+            This implementation automatically creates the table and skips rows that are already cached.
+        """
+        with sqlite3.connect(target) as con:
+            cur = con.cursor()
+
+            columns = None
+            for data in data_it:
+                assert(isinstance(data, dict))
+
+                # Extracting columns from data.
+                row_columns = list(data.keys())
+                assert(id_column_name in row_columns)
+
+                # Optionally create table.
+                if columns is None:
+
+                    # Setup list of columns.
+                    columns = row_columns
+                    # Place ID column first.
+                    columns.insert(0, columns.pop(columns.index(id_column_name)))
+
+                    SQLiteProvider.__create_table(
+                        columns=columns, table_name=table_name, cur=cur,
+                        id_column_name=id_column_name, id_column_type=id_column_type,
+                        sqlite3_column_types=["TEXT"] * len(columns))
+
+                # Check that each row reuses the columns of the first row.
+                assert all(column in columns for column in row_columns), f"Every column of {row_columns} is expected to be in {columns}!"
+
+                uid = data[id_column_name]
+                r = cur.execute(f"SELECT EXISTS(SELECT 1 FROM {table_name} WHERE {id_column_name}='{uid}');")
+                ans = r.fetchone()[0]
+                if ans == 1:
+                    continue
+
+                params = ", ".join(tuple(['?'] * (len(columns))))
+                row_columns_str = ", ".join(row_columns)
+                cur.execute(f"INSERT INTO {table_name}({row_columns_str}) VALUES ({params})",
+                            [data2col_func(c, data) for c in row_columns])
+                con.commit()
+
+            cur.close()
+
+    @staticmethod
+    def read(target, column_names=None, table="content"):
+        with sqlite3.connect(target) as conn:
+            cursor = conn.cursor()
+            cols = "*" if column_names is None else ",".join(column_names)
+            cursor.execute(f"SELECT {cols} FROM {table}")
+            for row in cursor:
+                yield row
+
+    @staticmethod
+    def get_columns(target, table="content"):
+        with sqlite3.connect(target) as conn:
+            cursor = conn.cursor()
+            cursor.execute(f"PRAGMA table_info({table})")
+            return [row[1] for row in cursor.fetchall()]
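As an illustrative usage sketch (not part of the package; the file and table names below are hypothetical), `write_auto` consumes an iterator of dictionaries and skips rows whose IDs are already cached, while `read` streams the rows back:

```python
# Hypothetical usage of SQLiteProvider; "demo.sqlite" / "content" are illustrative names.
from bulk_chain.core.provider_sqlite import SQLiteProvider

rows = [{"id": 0, "text": "hello"}, {"id": 1, "text": "world"}]

# Cache the rows; data2col_func maps (column name, row dict) to the stored value.
SQLiteProvider.write_auto(data_it=iter(rows), target="demo.sqlite",
                          data2col_func=lambda col, data: data[col],
                          table_name="content")

# Stream the cached rows back as tuples.
for row in SQLiteProvider.read("demo.sqlite", table="content"):
    print(row)  # e.g. (0, 'hello')
```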
bulk_chain/core/service_args.py
ADDED
@@ -0,0 +1,48 @@
+class CmdArgsService:
+
+    @staticmethod
+    def autocast(v):
+        for t in [int, float, str]:
+            try:
+                return t(v)
+            except (ValueError, TypeError):
+                pass
+
+    @staticmethod
+    def iter_arguments(lst):
+
+        def __release():
+            return key, buf[0] if len(buf) == 1 else buf
+
+        key = None
+        buf = []
+        for a in lst:
+            if a.startswith('--'):
+                # Release the previous key.
+                if key is not None:
+                    yield __release()
+                # Set the new key and an empty buffer.
+                key = a[2:]
+                buf = []
+            else:
+                # Append the argument into the buffer.
+                buf.append(a)
+
+        # Release the remaining params.
+        if len(buf) > 0:
+            yield __release()
+
+    @staticmethod
+    def partition_list(lst, sep):
+        """Slices a list in two, cutting at the index matching `sep`.
+        """
+        if sep in lst:
+            idx = lst.index(sep)
+            return (lst[:idx], lst[idx+1:])
+        else:
+            return (lst[:], None)
+
+    @staticmethod
+    def args_to_dict(args):
+        return {k: CmdArgsService.autocast(v) if not isinstance(v, list) else v
+                for k, v in CmdArgsService.iter_arguments(args)} if args is not None else {}
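For illustration, this is how the service splits a command line whose model-specific tail follows a `%%` separator (the argument values are hypothetical):

```python
# Hypothetical command line; `%%` separates native args from model args.
from bulk_chain.core.service_args import CmdArgsService

argv = ["prog", "--src", "data.csv", "%%", "--device", "cpu", "--temp", "0.1"]

native_args, model_args = CmdArgsService.partition_list(lst=argv, sep="%%")
print(native_args)                              # ['prog', '--src', 'data.csv']

# Collect `--key value` pairs, autocasting values to int/float/str.
print(CmdArgsService.args_to_dict(model_args))  # {'device': 'cpu', 'temp': 0.1}
```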
bulk_chain/core/service_csv.py
ADDED
@@ -0,0 +1,57 @@
+import csv
+import logging
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
+
+
+class CsvService:
+
+    @staticmethod
+    def write(target, lines_it):
+        logger.info(f"Saving: {target}")
+        with open(target, "w", newline="") as f:
+            w = csv.writer(f, delimiter="\t", quotechar='"', quoting=csv.QUOTE_MINIMAL)
+            for content in lines_it:
+                w.writerow(content)
+
+    @staticmethod
+    def write_handled(target, data_it, data2col_func, header):
+
+        def __it():
+            yield header
+            for data in data_it:
+                content = data2col_func(data)
+                assert(len(content) == len(header))
+                yield content
+
+        CsvService.write(target, lines_it=__it())
+
+    @staticmethod
+    def read(target, skip_header=False, cols=None, as_dict=False, row_id_key=None, **csv_kwargs):
+        assert (isinstance(row_id_key, str) or row_id_key is None)
+        assert (isinstance(cols, list) or cols is None)
+
+        header = None
+        with open(target, newline='\n') as f:
+            for row_id, row in enumerate(csv.reader(f, **csv_kwargs)):
+                if skip_header and row_id == 0:
+                    header = ([row_id_key] if row_id_key is not None else []) + row
+                    continue
+
+                # Determine the content we wish to return.
+                if cols is None:
+                    content = row
+                else:
+                    # Skip the prepended row_id column when mapping raw values onto the header.
+                    offset = 1 if row_id_key is not None else 0
+                    row_d = {header[col_ind + offset]: value for col_ind, value in enumerate(row)}
+                    content = [row_d[col_name] for col_name in cols]
+
+                # Optionally attach row_id to the content.
+                content = ([row_id-1] if row_id_key is not None else []) + content
+
+                if as_dict:
+                    assert (header is not None)
+                    assert (len(content) == len(header))
+                    yield {k: v for k, v in zip(header, content)}
+                else:
+                    yield content
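A short usage sketch (the input filename is hypothetical): reading a tab-separated file with a header row as dictionaries, with a generated `uid` attached:

```python
# Hypothetical input file "input.csv" with a header row such as "text<TAB>label".
from bulk_chain.core.service_csv import CsvService

for row in CsvService.read(target="input.csv", skip_header=True, as_dict=True,
                           row_id_key="uid", delimiter="\t"):
    print(row)  # e.g. {'uid': 0, 'text': '...', 'label': '...'}
```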
bulk_chain/core/service_data.py
ADDED
@@ -0,0 +1,22 @@
+from bulk_chain.core.utils import iter_params
+
+
+class DataService(object):
+
+    @staticmethod
+    def compose_prompt_text(prompt, data_dict, field_names):
+        assert(isinstance(data_dict, dict))
+        fmt_d = {col_name: data_dict[col_name] for col_name in field_names}
+
+        # Guarantee that items have the correct type.
+        for k, v in fmt_d.items():
+            if not isinstance(v, str):
+                raise Exception(f"'{k}' parameter is expected to be a string, but received '{v}'")
+
+        return prompt.format(**fmt_d)
+
+    @staticmethod
+    def get_prompt_text(prompt, data_dict, parse_fields_func=iter_params):
+        field_names = list(parse_fields_func(prompt))
+        return DataService.compose_prompt_text(
+            prompt=prompt, data_dict=data_dict, field_names=field_names)
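For illustration, `get_prompt_text` extracts the `{}`-wrapped field names from a template and fills them from the given dictionary:

```python
from bulk_chain.core.service_data import DataService

prompt = "Given the question '{text}', let's think step-by-step."
print(DataService.get_prompt_text(prompt=prompt, data_dict={"text": "What is 2+2?"}))
# Given the question 'What is 2+2?', let's think step-by-step.
```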
bulk_chain/core/service_json.py
ADDED
@@ -0,0 +1,26 @@
+import json
+
+
+class JsonService(object):
+
+    @staticmethod
+    def read_data(src):
+        assert (isinstance(src, str))
+        with open(src, "r") as f:
+            return json.load(f)
+
+    @staticmethod
+    def read_lines(src, row_id_key=None):
+        assert (isinstance(src, str))
+        with open(src, "r") as f:
+            for line_ind, line in enumerate(f.readlines()):
+                content = json.loads(line)
+                if row_id_key is not None:
+                    content[row_id_key] = line_ind
+                yield content
+
+    @staticmethod
+    def write_lines(target, data_it):
+        with open(target, "w") as f:
+            for item in data_it:
+                f.write(f"{json.dumps(item, ensure_ascii=False)}\n")
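A minimal round-trip sketch (the filename is hypothetical):

```python
from bulk_chain.core.service_json import JsonService

JsonService.write_lines(target="demo.jsonl", data_it=[{"text": "hello"}])
for item in JsonService.read_lines(src="demo.jsonl", row_id_key="uid"):
    print(item)  # {'text': 'hello', 'uid': 0}
```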
bulk_chain/core/service_llm.py
ADDED
@@ -0,0 +1,82 @@
+import logging
+
+from bulk_chain.core.llm_base import BaseLM
+from bulk_chain.core.service_data import DataService
+from bulk_chain.core.utils import iter_params
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
+
+
+def pad_str(text, pad):
+    return text.rjust(len(text) + pad, ' ')
+
+
+def text_wrap(content, width, handle_line=lambda line: line):
+    lines = []
+    for text in content.split('\n'):
+        for i in range(0, len(text), width):
+            line = handle_line(text[i:i + width])
+            lines.append(line)
+    return '\n'.join(lines)
+
+
+def nice_output(text, width, pad=4, remove_new_line=False):
+    short_text = text.replace("\n", "") if remove_new_line else text
+    return text_wrap(content=short_text, width=width, handle_line=lambda line: pad_str(line, pad=pad))
+
+
+def chat_with_lm(lm, chain=None, model_name=None):
+    assert(isinstance(lm, BaseLM))
+    assert(isinstance(chain, list))
+    assert(isinstance(model_name, str) or model_name is None)
+
+    do_exit = False
+    model_name = model_name if model_name is not None else "agent"
+
+    while not do_exit:
+
+        logger.info("----------------")
+
+        # Launching the CoT engine loop.
+        data_dict = {}
+        for prompt_args in chain:
+
+            # Processing the prompt.
+            prompt = prompt_args["prompt"]
+
+            # Filling the necessary parameters.
+            field_names = list(iter_params(prompt))
+            for ind, f_name in enumerate(field_names):
+
+                if f_name in data_dict:
+                    continue
+
+                user_input = input(f"Enter your prompt for `{f_name}` ({ind+1}/{len(field_names)}) "
+                                   f"(or 'exit' to quit): ")
+
+                if user_input.lower() == 'exit':
+                    do_exit = True
+                    break
+
+                data_dict[f_name] = user_input
+
+            if do_exit:
+                break
+
+            # Finally asking LLM.
+            actual_prompt = DataService.get_prompt_text(prompt=prompt, data_dict=data_dict)
+
+            # Log the meta information passed to LLM.
+            pad = 4
+            logger.info(pad_str(f"{model_name} (ask) ->", pad=pad))
+            logger.info(nice_output(actual_prompt, pad=pad*2, remove_new_line=True, width=80))
+
+            # Response.
+            response = lm.ask(actual_prompt)
+            logger.info(pad_str(f"{model_name} (resp)->", pad=pad))
+            logger.info(nice_output(response, pad=pad*2, remove_new_line=False, width=80))
+
+            # Collecting the answer for the next turn.
+            data_dict[prompt_args["out"]] = response
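For illustration, the wrapping helpers above can be exercised on their own (`chat_with_lm` itself is interactive and reads from stdin):

```python
from bulk_chain.core.service_llm import nice_output

# Wraps a 100-character string into lines of at most 40 characters,
# each left-padded with two spaces.
print(nice_output("a" * 100, width=40, pad=2))
```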
bulk_chain/core/service_schema.py
ADDED
@@ -0,0 +1,34 @@
+class SchemaService(object):
+
+    def __init__(self, json_data):
+        self.src = json_data
+        self.name = self.src["name"]
+        self.r2p, self.p2r, self.cot_args, self.chain = SchemaService.__init_schema(prompts=json_data["schema"])
+
+    @classmethod
+    def from_prompt(cls, prompt):
+        prompt_schema = {"name": "prompt", "schema": [{"prompt": prompt, "out": "response", "in": "prompt"}]}
+        return cls(prompt_schema)
+
+    @staticmethod
+    def __init_schema(prompts):
+
+        schema_args = {}
+        schema_r2p = {}
+        schema_p2r = {}
+        chain = []
+
+        for prompt in prompts:
+            r_col_name = prompt["out"]
+            p_col_name = r_col_name + "_prompt" if "in" not in prompt else prompt["in"]
+
+            assert r_col_name not in schema_r2p, f"`{r_col_name}` has already been declared!"
+            assert p_col_name not in schema_p2r, f"`{p_col_name}` has already been declared!"
+
+            schema_r2p[r_col_name] = p_col_name
+            schema_p2r[p_col_name] = r_col_name
+            schema_args[p_col_name] = prompt
+            schema_args[r_col_name] = None
+            chain.append(prompt)
+
+        return schema_r2p, schema_p2r, schema_args, chain
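As an illustration, `from_prompt` wraps a single prompt into a one-step schema and exposes the response-to-prompt mappings:

```python
from bulk_chain.core.service_schema import SchemaService

schema = SchemaService.from_prompt("Summarize: {text}")
print(schema.name)  # prompt
print(schema.r2p)   # {'response': 'prompt'}
print(schema.p2r)   # {'prompt': 'response'}
```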
bulk_chain/core/utils.py
ADDED
@@ -0,0 +1,101 @@
+import importlib
+import logging
+import sys
+from collections import Counter
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
+
+
+def find_by_prefix(d, key):
+    """
+    d: dict (str, val).
+    """
+    assert(isinstance(d, dict))
+    assert(isinstance(key, str))
+
+    # We first check the full match.
+    for k, value in d.items():
+        if k == key:
+            return value
+
+    # If we can't establish a full match, then we seek by prefix.
+    matches = []
+    for k, value in d.items():
+        if key.startswith(k):
+            matches.append(k)
+
+    if len(matches) > 1:
+        raise Exception(f"There are multiple entries that are related to `{key}`: {matches}")
+    if len(matches) == 0:
+        raise Exception(f"No entries were found for {key}!")
+
+    return d[matches[0]]
+
+
+def iter_params(text):
+    assert(isinstance(text, str))
+    beg = 0
+    while beg < len(text):
+        try:
+            pb = text.index('{', beg)
+        except ValueError:
+            break
+        pe = text.index('}', pb + 1)
+        # Yield argument.
+        yield text[pb+1:pe]
+        beg = pe+1
+
+
+def format_model_name(name):
+    return name.replace("/", "_")
+
+
+def parse_filepath(filepath, default_filepath=None, default_ext=None):
+    """ This is an auxiliary function for handling sources and targets from the cmd string.
+    """
+    if filepath is None:
+        return default_filepath, default_ext, None
+    info = filepath.split(":")
+    filepath = info[0]
+    meta = info[1] if len(info) > 1 else None
+    ext = filepath.split('.')[-1] if default_ext is None else default_ext
+    return filepath, ext, meta
+
+
+def handle_table_name(name):
+    return name.\
+        replace('-', '_').\
+        replace('.', "_")
+
+
+def auto_import(name, is_class=False):
+    """ Import from the external python packages.
+    """
+    def __get_module(comps_list):
+        return importlib.import_module(".".join(comps_list))
+
+    components = name.split('.')
+    m = getattr(__get_module(components[:-1]), components[-1])
+
+    return m() if is_class else m
+
+
+def dynamic_init(class_dir, class_filepath, class_name=None):
+    sys.path.append(class_dir)
+    class_path_list = class_filepath.split('/')
+    class_path_list[-1] = '.'.join(class_path_list[-1].split('.')[:-1])
+    class_name = class_path_list[-1].title() if class_name is None else class_name
+    class_path = ".".join(class_path_list + [class_name])
+    logger.info(f"Dynamic loading for the file and class `{class_path}`")
+    cls = auto_import(class_path, is_class=False)
+    return cls
+
+
+def optional_limit_iter(it_data, limit=None):
+    counter = Counter()
+    for data in it_data:
+        counter["returned"] += 1
+        if limit is not None and counter["returned"] > limit:
+            break
+        yield data
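Two of these helpers in action (values are illustrative): `iter_params` lists the `{}`-wrapped template fields, and `parse_filepath` splits an optional `:meta` suffix (used, e.g., for a table name) off a target path:

```python
from bulk_chain.core.utils import iter_params, parse_filepath

print(list(iter_params("For '{text}' the steps are '{steps}'.")))
# ['text', 'steps']

print(parse_filepath("output.sqlite:my_table"))
# ('output.sqlite', 'sqlite', 'my_table')
```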
bulk_chain/infer.py
ADDED
@@ -0,0 +1,170 @@
+import argparse
+import logging
+import os
+import sys
+
+from tqdm import tqdm
+
+from os.path import join, basename
+
+from bulk_chain.core.llm_base import BaseLM
+from bulk_chain.core.provider_sqlite import SQLiteProvider
+from bulk_chain.core.service_args import CmdArgsService
+from bulk_chain.core.service_csv import CsvService
+from bulk_chain.core.service_data import DataService
+from bulk_chain.core.service_json import JsonService
+from bulk_chain.core.service_llm import chat_with_lm
+from bulk_chain.core.service_schema import SchemaService
+from bulk_chain.core.utils import dynamic_init, find_by_prefix, handle_table_name, optional_limit_iter, parse_filepath
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
+
+
+CWD = os.getcwd()
+
+
+def init_llm(**model_kwargs):
+    """ This method performs dynamic initialization of an LLM from a third-party resource.
+        NOTE: relies on the module-level `args` parsed in __main__.
+    """
+
+    # List of the supported models and their API wrappers.
+    models_preset = {
+        "dynamic": lambda: dynamic_init(class_dir=CWD, class_filepath=llm_model_name,
+                                        class_name=llm_model_params)(**model_kwargs)
+    }
+
+    # Initialize LLM model.
+    params = args.adapter.split(':')
+    llm_model_type = params[0]
+    llm_model_name = params[1] if len(params) > 1 else params[-1]
+    llm_model_params = ':'.join(params[2:]) if len(params) > 2 else None
+    llm = find_by_prefix(d=models_preset, key=llm_model_type)()
+
+    return llm, llm_model_name
+
+
+def init_schema(json_filepath):
+    return SchemaService(json_data=JsonService.read_data(json_filepath))
+
+
+def iter_content(input_dicts_iter, llm, schema, cache_target, cache_table):
+    """ This method represents the Python API aimed at applying `llm` to an
+        iterator of input dicts, caching the results in the SQLite database
+        referred to by `cache_target`, according to the given `schema`.
+    """
+    assert (isinstance(llm, BaseLM))
+    assert (isinstance(schema, SchemaService))
+    assert (isinstance(cache_target, str))
+    assert (isinstance(cache_table, str))
+
+    infer_modes = {
+        "default": lambda prompt: llm.ask(prompt[:args.limit_prompt] if args.limit_prompt is not None else prompt)
+    }
+
+    def optional_update_data_records(c, data):
+        assert (isinstance(c, str))
+
+        if c in schema.p2r:
+            data[c] = DataService.get_prompt_text(prompt=data[c]["prompt"], data_dict=data)
+        if c in schema.r2p:
+            p_column = schema.r2p[c]
+            # This instruction takes a lot of time in a non-batching mode.
+            data[c] = infer_modes["default"](data[p_column])
+
+        return data[c]
+
+    cache_providers = {
+        "sqlite": lambda filepath, table_name, data_it: SQLiteProvider.write_auto(
+            data_it=data_it, target=filepath,
+            data2col_func=optional_update_data_records,
+            table_name=handle_table_name(table_name if table_name is not None else "contents"),
+            id_column_name="uid")
+    }
+
+    # We optionally wrap into a limiter.
+    queries_it = optional_limit_iter(
+        it_data=map(lambda data: data.update(schema.cot_args) or data, input_dicts_iter),
+        limit=args.limit)
+
+    # Provide data caching.
+    cache_providers["sqlite"](cache_target, table_name=cache_table, data_it=tqdm(queries_it, desc="Iter content"))
+
+    return SQLiteProvider.read(cache_target, table=cache_table)
+
+
+if __name__ == '__main__':
+
+    parser = argparse.ArgumentParser(description="Instruct LLM inference based on a CoT schema")
+    parser.add_argument('--adapter', dest='adapter', type=str, default=None)
+    parser.add_argument('--src', dest='src', type=str, default=None)
+    parser.add_argument('--schema', dest='schema', type=str, default=None,
+                        help="Path to the JSON file that describes the schema")
+    parser.add_argument('--csv-sep', dest='csv_sep', type=str, default='\t')
+    parser.add_argument('--csv-escape-char', dest='csv_escape_char', type=str, default=None)
+    parser.add_argument('--to', dest='to', type=str, default=None, choices=["csv", "sqlite"])
+    parser.add_argument('--output', dest='output', type=str, default=None)
+    parser.add_argument('--limit', dest='limit', type=int, default=None,
+                        help="Limit the amount of source texts for prompting.")
+    parser.add_argument('--limit-prompt', dest="limit_prompt", type=int, default=None,
+                        help="Optionally trim the prompt to the specified amount of characters.")
+
+    native_args, model_args = CmdArgsService.partition_list(lst=sys.argv, sep="%%")
+
+    args = parser.parse_args(args=native_args[1:])
+
+    # Initialize Large Language Model.
+    llm, llm_model_name = init_llm(**CmdArgsService.args_to_dict(model_args))
+
+    # Setup schema.
+    schema = init_schema(args.schema)
+    if schema is not None:
+        logger.info(f"Using schema: {schema.name}")
+
+    input_providers = {
+        None: lambda _: chat_with_lm(llm, chain=schema.chain, model_name=llm_model_name),
+        "csv": lambda filepath: CsvService.read(target=filepath, row_id_key="uid", delimiter=args.csv_sep,
+                                                as_dict=True, skip_header=True, escapechar=args.csv_escape_char),
+        "jsonl": lambda filepath: JsonService.read_lines(src=filepath, row_id_key="uid")
+    }
+
+    output_providers = {
+        "csv": lambda filepath, data_it, header:
+            CsvService.write_handled(target=filepath, data_it=data_it, header=header, data2col_func=lambda v: list(v)),
+        "jsonl": lambda filepath, data_it, header:
+            JsonService.write_lines(target=filepath,
+                                    data_it=map(lambda item: {key: item[i] for i, key in enumerate(header)}, data_it))
+    }
+
+    # Setup output.
+    args.output = args.output.format(model=llm.name()) if args.output is not None else args.output
+    tgt_filepath, tgt_ext, tgt_meta = parse_filepath(args.output, default_ext=args.to)
+
+    # The input extension type defines the provider.
+    src_filepath, src_ext, src_meta = parse_filepath(args.src)
+
+    # Check whether we are in chat mode.
+    if src_ext is None:
+        input_providers[src_ext](None)
+        exit(0)
+
+    # Setup the cache target as well as the related table.
+    cache_target = "".join(["_".join([join(CWD, basename(src_filepath)), llm.name(), schema.name]), ".sqlite"]) \
+        if tgt_filepath is None else tgt_filepath
+    cache_table = handle_table_name(tgt_meta if tgt_meta is not None else "contents")
+
+    data_it = iter_content(input_dicts_iter=input_providers[src_ext](src_filepath),
+                           schema=schema,
+                           llm=llm,
+                           cache_target=cache_target,
+                           cache_table=cache_table)
+
+    # Setup the output target.
+    tgt_ext = src_ext if tgt_ext is None else tgt_ext
+    output_target = "".join(["_".join([join(CWD, basename(src_filepath)), llm.name(), schema.name]), f".{tgt_ext}"]) \
+        if tgt_filepath is None else tgt_filepath
+
+    # Perform the output writing process.
+    output_providers[tgt_ext](filepath=output_target,
+                              data_it=data_it,
+                              header=SQLiteProvider.get_columns(target=cache_target, table=cache_table))
bulk_chain-0.24.0.dist-info/LICENSE
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2024 Nicolay Rusnachenko
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
bulk_chain-0.24.0.dist-info/METADATA
ADDED
@@ -0,0 +1,92 @@
+Metadata-Version: 2.1
+Name: bulk_chain
+Version: 0.24.0
+Summary: A lightweight, no-strings-attached Chain-of-Thought framework for your LLM, ensuring reliable results for bulk input requests.
+Home-page: https://github.com/nicolay-r/bulk-chain
+Author: Nicolay Rusnachenko
+Author-email: rusnicolay@gmail.com
+License: MIT License
+Keywords: natural language processing,chain-of-thought,reasoning
+Classifier: Programming Language :: Python
+Classifier: Programming Language :: Python :: 3.9
+Classifier: Topic :: Software Development :: Libraries :: Python Modules
+Classifier: Topic :: Scientific/Engineering :: Information Analysis
+Classifier: Topic :: Text Processing :: Linguistic
+Requires-Python: >=3.6
+Description-Content-Type: text/markdown
+License-File: LICENSE
+Requires-Dist: tqdm
+
+# bulk-chain
+
+[](https://colab.research.google.com/github/nicolay-r/bulk-chain/blob/master/bulk_chain_tutorial.ipynb)
+
+A lightweight, no-strings-attached **[Chain-of-Thought](https://arxiv.org/abs/2201.11903) framework** for your LLM, ensuring reliable results for bulk input requests stored in `CSV` / `JSONL` / `sqlite`.
+It allows applying a series of prompts formed into a `schema` (see the [related section](#chain-of-thought-schema)).
+
+### Features
+* ✅ **No-strings**: you're free from LLM dependencies and flexible in `venv` customization.
+* ✅ **Provides an iterator over an unlimited amount of input contexts** served in `CSV`/`JSONL`.
+* ✅ **Progress caching**: withstands exceptions during LLM calls by using the `sqlite3` engine to cache LLM answers;
+* ✅ **Supports schema descriptions** for the Chain-of-Thought concept.
+
+# Installation
+
+```bash
+pip install git+https://github.com/nicolay-r/bulk-chain@master
+```
+
+## Chain-of-Thought Schema
+
+To declare a Chain-of-Thought (CoT) schema, this project exploits the `JSON` format.
+This format adopts a `name` field for declaring the schema name, while `schema` is a list of CoT instructions for the Large Language Model.
+
+Each step represents a dictionary with `prompt` and `out` keys that correspond to the input prompt and the output variable name respectively.
+All the variable names are expected to be mentioned in `{}`.
+
+Below is an example of how to declare your own schema:
+
+```json
+{
+  "name": "schema-name",
+  "schema": [
+    {"prompt": "Given the question '{text}', let's think step-by-step.",
+     "out": "steps"},
+    {"prompt": "For the question '{text}' the reasoning steps are '{steps}'. What would be an answer?",
+     "out": "answer"}
+  ]
+}
+```
+
+More templates are available [here](/ext/schema/thor_cot_schema.json).
+
+# Usage
+
+Just **three** simple steps:
+
+1. Define your [CoT Schema](#chain-of-thought-schema), or fetch it as shown below:
+```bash
+!wget https://raw.githubusercontent.com/nicolay-r/bulk-chain/refs/heads/master/ext/schema/default.json
+```
+2. Fetch or write your own **model**, or pick one [preset here](/ext/):
+```bash
+!wget https://raw.githubusercontent.com/nicolay-r/bulk-chain/refs/heads/master/ext/flan_t5.py
+```
+3. Launch inference (in chat mode):
+```bash
+!python -m bulk_chain.infer \
+    --schema "default.json" \
+    --adapter "dynamic:flan_t5.py:FlanT5" \
+    %% \
+    --device "cpu" \
+    --temp 0.1
+```
+
+# Embed your LLM
+
+All you have to do is implement the `BaseLM` class, which includes:
+* `__init__` -- for initialization;
+* `ask(prompt)` -- to infer your model with the given `prompt`.
+
+See examples with models [here](/ext).
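As an illustrative skeleton for such a wrapper (not one of the shipped `/ext` adapters; it assumes `BaseLM` requires nothing beyond the two methods listed in the README):

```python
# Hypothetical adapter skeleton; see /ext for the actual shipped examples.
from bulk_chain.core.llm_base import BaseLM

class MyLM(BaseLM):

    def __init__(self, temp=0.1, **kwargs):
        # Set up your client or load your model here; if BaseLM.__init__
        # expects arguments (e.g. a model name), forward them via super().
        self.temp = temp

    def ask(self, prompt):
        # Query the underlying model and return its textual answer.
        return "..."
```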
bulk_chain-0.24.0.dist-info/RECORD
ADDED
@@ -0,0 +1,17 @@
+bulk_chain/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+bulk_chain/infer.py,sha256=HXFcl_7u5sgybDv_v5_up-Mpe-zSX0vtgsG1Wh1h-UA,7184
+bulk_chain/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+bulk_chain/core/llm_base.py,sha256=5js2RJLpNS5t-De-xTpZCbLMgbz3F_b9tU_CtXhy02I,259
+bulk_chain/core/provider_sqlite.py,sha256=D7axdeTDvv-ULHKTalFWbeKC3WaYOLI7lVrXFAXkct8,3213
+bulk_chain/core/service_args.py,sha256=Qr3rHsAB8wnajB-DbU-GjiEpRZFP4D6s1lVTpLkPPX4,1294
+bulk_chain/core/service_csv.py,sha256=-m8tNN9aIqRfJa4sPUX8ZUDP4W0fgnnOR3_0PapepDY,1984
+bulk_chain/core/service_data.py,sha256=18gQwSCTEsI7XFukq8AE5lDJX_QQRpasaH69g6EddV0,797
+bulk_chain/core/service_json.py,sha256=alYqTQbBjAcCh7anSTOZs1CLJbiWrLPpzLcoADstD0Q,743
+bulk_chain/core/service_llm.py,sha256=tYgMphJkXunhxdrThdfI4eM8qQTCZfEM1kabbReVjuQ,2726
+bulk_chain/core/service_schema.py,sha256=JVhOv2YP5VEtiwOq_zgCzhS2uF_BOATAgg6fmKRf2NQ,1209
+bulk_chain/core/utils.py,sha256=UV6Cefaw7yZiYblsCr-s9LsbcI83xe7eESBvha9A2Og,2784
+bulk_chain-0.24.0.dist-info/LICENSE,sha256=VF9SjNpwwSSFEY_eP_8A1ocDCrbwfjI1pZexXdCkOwo,1076
+bulk_chain-0.24.0.dist-info/METADATA,sha256=l_RpSlOGQzuA0buVn7I54XN_c9Fn_5Y6lhNPkqlhYqo,3496
+bulk_chain-0.24.0.dist-info/WHEEL,sha256=pL8R0wFFS65tNSRnaOVrsw9EOkOqxLrlUPenUYnJKNo,91
+bulk_chain-0.24.0.dist-info/top_level.txt,sha256=Hxq_wyH-GDXKBaA63UfBIiMJO2eCHJG5EOrXDphpeB4,11
+bulk_chain-0.24.0.dist-info/RECORD,,
bulk_chain-0.24.0.dist-info/top_level.txt
ADDED
@@ -0,0 +1 @@
+bulk_chain