bulk-chain 0.24.2__tar.gz → 0.25.0__tar.gz
- {bulk_chain-0.24.2 → bulk_chain-0.25.0}/PKG-INFO +37 -26
- {bulk_chain-0.24.2 → bulk_chain-0.25.0}/README.md +36 -23
- bulk_chain-0.25.0/bulk_chain/api.py +79 -0
- bulk_chain-0.25.0/bulk_chain/core/llm_base.py +52 -0
- {bulk_chain-0.24.2 → bulk_chain-0.25.0}/bulk_chain/core/service_args.py +25 -6
- bulk_chain-0.25.0/bulk_chain/core/service_batch.py +51 -0
- {bulk_chain-0.24.2 → bulk_chain-0.25.0}/bulk_chain/core/service_data.py +4 -0
- bulk_chain-0.25.0/bulk_chain/core/service_dict.py +10 -0
- {bulk_chain-0.24.2 → bulk_chain-0.25.0}/bulk_chain/core/service_llm.py +3 -3
- {bulk_chain-0.24.2 → bulk_chain-0.25.0}/bulk_chain/core/service_schema.py +1 -2
- bulk_chain-0.25.0/bulk_chain/infer.py +191 -0
- {bulk_chain-0.24.2 → bulk_chain-0.25.0}/bulk_chain.egg-info/PKG-INFO +37 -26
- {bulk_chain-0.24.2 → bulk_chain-0.25.0}/bulk_chain.egg-info/SOURCES.txt +4 -1
- {bulk_chain-0.24.2 → bulk_chain-0.25.0}/setup.py +1 -1
- bulk_chain-0.25.0/test/test_api.py +42 -0
- bulk_chain-0.25.0/test/test_cmdargs.py +20 -0
- bulk_chain-0.24.2/bulk_chain/core/llm_base.py +0 -35
- bulk_chain-0.24.2/bulk_chain/infer.py +0 -176
- bulk_chain-0.24.2/bulk_chain.egg-info/requires.txt +0 -2
- bulk_chain-0.24.2/test/test_cmdargs.py +0 -9
- {bulk_chain-0.24.2 → bulk_chain-0.25.0}/LICENSE +0 -0
- {bulk_chain-0.24.2 → bulk_chain-0.25.0}/bulk_chain/__init__.py +0 -0
- {bulk_chain-0.24.2 → bulk_chain-0.25.0}/bulk_chain/core/__init__.py +0 -0
- {bulk_chain-0.24.2 → bulk_chain-0.25.0}/bulk_chain/core/service_json.py +0 -0
- {bulk_chain-0.24.2 → bulk_chain-0.25.0}/bulk_chain/core/utils.py +0 -0
- {bulk_chain-0.24.2 → bulk_chain-0.25.0}/bulk_chain.egg-info/dependency_links.txt +0 -0
- {bulk_chain-0.24.2 → bulk_chain-0.25.0}/bulk_chain.egg-info/top_level.txt +0 -0
- {bulk_chain-0.24.2 → bulk_chain-0.25.0}/setup.cfg +0 -0
- {bulk_chain-0.24.2 → bulk_chain-0.25.0}/test/test_args_seeking.py +0 -0
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: bulk_chain
-Version: 0.
+Version: 0.25.0
 Summary: A lightweight, no-strings-attached Chain-of-Thought framework for your LLM, ensuring reliable results for bulk input requests.
 Home-page: https://github.com/nicolay-r/bulk-chain
 Author: Nicolay Rusnachenko
@@ -15,33 +15,42 @@ Classifier: Topic :: Text Processing :: Linguistic
 Requires-Python: >=3.6
 Description-Content-Type: text/markdown
 License-File: LICENSE
-Requires-Dist: tqdm
-Requires-Dist: source-iter==0.24.2

-# bulk-chain 0.
+# bulk-chain 0.25.0

 [](https://colab.research.google.com/github/nicolay-r/bulk-chain/blob/master/bulk_chain_tutorial.ipynb)
 [](https://x.com/nicolayr_/status/1847969224636961033)
+[](https://pypistats.org/packages/bulk-chain)

 <p align="center">
 <img src="logo.png"/>
 </p>

-A lightweight, no-strings-attached **[Chain-of-Thought](https://arxiv.org/abs/2201.11903)
-It allows applying series of prompts formed into `schema` (See [related section](#chain-of-thought-schema))
+A lightweight, no-strings-attached **framework** for your LLM that allows applying [Chain-of-Thought](https://arxiv.org/abs/2201.11903) prompt `schema` (See [related section](#chain-of-thought-schema)) towards a massive textual collections.

-### Features
+### Main Features
 * ✅ **No-strings**: you're free to LLM dependencies and flexible `venv` customization.
-* ✅ **Provides iterator over infinite amount of input contexts** served in `CSV`/`JSONL`.
-* ✅ **Progress caching**: withstanding exception during LLM calls by using `sqlite3` engine for caching LLM answers;
 * ✅ **Support schemas descriptions** for Chain-of-Thought concept.
+* ✅ **Provides iterator over infinite amount of input contexts** served in `CSV`/`JSONL`.
+
+### Extra Features
+* ✅ **Progress caching [for remote LLMs]**: withstanding exception during LLM calls by using `sqlite3` engine for caching LLM answers;
+

 # Installation

+From PyPI:
+
 ```bash
 pip install bulk-chain
 ```

+or latest version from here:
+
+```bash
+pip install git+https://github.com/nicolay-r/bulk-chain@master
+```
+
 ## Chain-of-Thought Schema

 To declare Chain-of-Though (CoT) schema, this project exploits `JSON` format.
@@ -64,35 +73,37 @@ Below, is an example on how to declare your own schema:
 }
 ```

-Another templates are available [here](/ext/schema/
+Another templates are available [here](/ext/schema/).

 # Usage

-
+Preliminary steps:

-1. Define your [
-
-
-
-
-
-
-
+1. Define your [schema](#chain-of-thought-schema) ([Example for Sentiment Analysis](/ext/schema/thor_cot_schema.json)))
+2. Wrap or pick **LLM model** from the [list of presets](/ext/).
+
+## API
+
+Please take a look at the [**related Wiki page**](https://github.com/nicolay-r/bulk-chain/wiki)
+
+## Shell
+
+> **NOTE:** You have to install `source-iter` package

-3. Launch inference in (chat mode):
 ```bash
-
-    --
-    --
-
-
+python3 -m bulk_chain.infer \
+    --src "<PATH-TO-YOUR-CSV-or-JSONL>" \
+    --schema "ext/schema/default.json" \
+    --adapter "dynamic:ext/replicate.py:Replicate" \
+    %%m \
+    --api_token "<REPLICATE-API-TOKEN>" \
     --temp 0.1
 ```

 # Embed your LLM

 All you have to do is to implement `BaseLM` class, that includes:
-* `__init__` -- for
+* `__init__` -- for setting up *batching mode support* and (optional) *model name*;
 * `ask(prompt)` -- infer your model with the given `prompt`.

 See examples with models [here](/ext).
@@ -1,27 +1,38 @@
-# bulk-chain 0.
+# bulk-chain 0.25.0

 [](https://colab.research.google.com/github/nicolay-r/bulk-chain/blob/master/bulk_chain_tutorial.ipynb)
 [](https://x.com/nicolayr_/status/1847969224636961033)
+[](https://pypistats.org/packages/bulk-chain)

 <p align="center">
 <img src="logo.png"/>
 </p>

-A lightweight, no-strings-attached **[Chain-of-Thought](https://arxiv.org/abs/2201.11903)
-It allows applying series of prompts formed into `schema` (See [related section](#chain-of-thought-schema))
+A lightweight, no-strings-attached **framework** for your LLM that allows applying [Chain-of-Thought](https://arxiv.org/abs/2201.11903) prompt `schema` (See [related section](#chain-of-thought-schema)) towards a massive textual collections.

-### Features
+### Main Features
 * ✅ **No-strings**: you're free to LLM dependencies and flexible `venv` customization.
-* ✅ **Provides iterator over infinite amount of input contexts** served in `CSV`/`JSONL`.
-* ✅ **Progress caching**: withstanding exception during LLM calls by using `sqlite3` engine for caching LLM answers;
 * ✅ **Support schemas descriptions** for Chain-of-Thought concept.
+* ✅ **Provides iterator over infinite amount of input contexts** served in `CSV`/`JSONL`.
+
+### Extra Features
+* ✅ **Progress caching [for remote LLMs]**: withstanding exception during LLM calls by using `sqlite3` engine for caching LLM answers;
+

 # Installation

+From PyPI:
+
 ```bash
 pip install bulk-chain
 ```

+or latest version from here:
+
+```bash
+pip install git+https://github.com/nicolay-r/bulk-chain@master
+```
+
 ## Chain-of-Thought Schema

 To declare Chain-of-Though (CoT) schema, this project exploits `JSON` format.
@@ -44,35 +55,37 @@ Below, is an example on how to declare your own schema:
 }
 ```

-Another templates are available [here](/ext/schema/
+Another templates are available [here](/ext/schema/).

 # Usage

-
+Preliminary steps:

-1. Define your [
-
-
-
-
-
-
-
+1. Define your [schema](#chain-of-thought-schema) ([Example for Sentiment Analysis](/ext/schema/thor_cot_schema.json)))
+2. Wrap or pick **LLM model** from the [list of presets](/ext/).
+
+## API
+
+Please take a look at the [**related Wiki page**](https://github.com/nicolay-r/bulk-chain/wiki)
+
+## Shell
+
+> **NOTE:** You have to install `source-iter` package

-3. Launch inference in (chat mode):
 ```bash
-
-    --
-    --
-
-
+python3 -m bulk_chain.infer \
+    --src "<PATH-TO-YOUR-CSV-or-JSONL>" \
+    --schema "ext/schema/default.json" \
+    --adapter "dynamic:ext/replicate.py:Replicate" \
+    %%m \
+    --api_token "<REPLICATE-API-TOKEN>" \
     --temp 0.1
 ```

 # Embed your LLM

 All you have to do is to implement `BaseLM` class, that includes:
-* `__init__` -- for
+* `__init__` -- for setting up *batching mode support* and (optional) *model name*;
 * `ask(prompt)` -- infer your model with the given `prompt`.

 See examples with models [here](/ext).
@@ -0,0 +1,79 @@
+import os
+from itertools import chain
+
+from bulk_chain.core.llm_base import BaseLM
+from bulk_chain.core.service_batch import BatchIterator, BatchService
+from bulk_chain.core.service_data import DataService
+from bulk_chain.core.service_dict import DictionaryService
+from bulk_chain.core.service_json import JsonService
+from bulk_chain.core.service_schema import SchemaService
+
+
+INFER_MODES = {
+    "default": lambda llm, prompt, limit_prompt=None: llm.ask_core(
+        prompt[:limit_prompt] if limit_prompt is not None else prompt),
+    "batch": lambda llm, batch, limit_prompt=None: llm.ask_core(
+        DataService.limit_prompts(batch, limit=limit_prompt))
+}
+
+
+CWD = os.getcwd()
+
+
+def _update_batch_content(c, batch, schema, infer_func):
+    assert (isinstance(batch, list))
+    assert (isinstance(c, str))
+
+    if c in schema.p2r:
+        for batch_item in batch:
+            batch_item[c] = DataService.get_prompt_text(prompt=batch_item[c]["prompt"], data_dict=batch_item)
+    if c in schema.r2p:
+        p_column = schema.r2p[c]
+        # This instruction takes a lot of time in a non-batching mode.
+        BatchService.handle_param_as_batch(batch=batch,
+                                           src_param=p_column,
+                                           tgt_param=c,
+                                           handle_func=lambda b: infer_func(b))
+
+
+def _infer_batch(batch, schema, infer_func, cols=None):
+    assert (isinstance(batch, list))
+    assert (callable(infer_func))
+
+    if len(batch) == 0:
+        return batch
+
+    if cols is None:
+        first_item = batch[0]
+        cols = first_item.keys() if cols is None else cols
+
+    for c in cols:
+        _update_batch_content(c=c, batch=batch, schema=schema, infer_func=infer_func)
+
+    return batch
+
+
+def iter_content(input_dicts_it, llm, schema, batch_size=1, return_batch=True, limit_prompt=None):
+    """ This method represent Python API aimed at application of `llm` towards
+        iterator of input_dicts via cache_target that refers to the SQLite using
+        the given `schema`
+    """
+    assert (isinstance(llm, BaseLM))
+
+    # Quick initialization of the schema.
+    if isinstance(schema, str):
+        schema = JsonService.read(schema)
+    if isinstance(schema, dict):
+        schema = SchemaService(json_data=schema)
+
+    prompts_it = map(
+        lambda data: DictionaryService.custom_update(src_dict=data, other_dict=schema.cot_args),
+        input_dicts_it
+    )
+
+    content_it = (_infer_batch(batch=batch,
+                               infer_func=lambda batch: INFER_MODES["batch"](llm, batch, limit_prompt),
+                               schema=schema)
+                  for batch in BatchIterator(prompts_it, batch_size=batch_size))
+
+    yield from content_it if return_batch else chain.from_iterable(content_it)
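The last line of `iter_content` in the hunk above switches between yielding whole batches and yielding single items. The following is a minimal sketch of that switch, assuming plain lists of dictionaries as batches; the helper name `iter_results` and the sample data are hypothetical and not part of the package.

```python
from itertools import chain


def iter_results(batches, return_batch=True):
    """Hypothetical stand-in for the final line of `iter_content`."""
    # Either pass every batch through as a list, or flatten all batches
    # into a single stream of items.
    yield from batches if return_batch else chain.from_iterable(batches)


batches = [[{"uid": 1}, {"uid": 2}], [{"uid": 3}]]

as_batches = list(iter_results(iter(batches), return_batch=True))
as_items = list(iter_results(iter(batches), return_batch=False))

print(as_batches)
print(as_items)
```

With `return_batch=False` the caller does not need to know the batch size that was used for inference.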
@@ -0,0 +1,52 @@
+import logging
+import time
+
+from bulk_chain.core.utils import format_model_name
+
+
+class BaseLM(object):
+
+    def __init__(self, name=None, attempts=None, delay_sec=1, enable_log=True,
+                 support_batching=False, **kwargs):
+
+        self.__name = name
+        self.__attempts = 1 if attempts is None else attempts
+        self.__delay_sec = delay_sec
+        self.__support_batching = support_batching
+
+        if enable_log:
+            self.__logger = logging.getLogger(__name__)
+            logging.basicConfig(level=logging.INFO)
+
+    def ask_core(self, batch):
+
+        for i in range(self.__attempts):
+            try:
+                if self.__support_batching:
+                    # Launch in batch mode.
+                    content = self.ask(batch)
+                else:
+                    # Launch in non-batch mode.
+                    assert len(batch) == 1, "The LM does not support batching," \
+                                            f" while size of the content is {len(batch)} which is not equal 1. " \
+                                            f"Please enable batch-supporting or set required inference settings."
+                    content = batch[0]
+
+                response = self.ask(content)
+
+                # Wrapping into batch the response in the case of non-batching mode.
+                return response if self.__support_batching else [response]
+
+            except Exception as e:
+                if self.__logger is not None:
+                    self.__logger.info("Unable to infer the result. Try {} out of {}.".format(i, self.__attempts))
+                    self.__logger.info(e)
+                time.sleep(self.__delay_sec)
+
+        raise Exception("Can't infer")
+
+    def ask(self, content):
+        raise NotImplemented()
+
+    def name(self):
+        return format_model_name(self.__name)
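The retry pattern around `ask` in the hunk above can be tried out without the package. The sketch below is a simplified standalone version of that loop, not the class itself: `ask_with_retries` and `flaky_ask` are hypothetical names, logging is left out, and the delay defaults to zero so the example runs instantly.

```python
import time


def ask_with_retries(ask, batch, attempts=1, delay_sec=0.0, support_batching=False):
    """Hypothetical, simplified version of the retry loop in `ask_core`."""
    for _ in range(attempts):
        try:
            if support_batching:
                # A batching model receives the whole list and returns a list.
                return ask(batch)
            # A non-batching model accepts exactly one prompt per call.
            assert len(batch) == 1, "batching is not supported by this model"
            # Wrap the single response so callers always receive a list.
            return [ask(batch[0])]
        except Exception:
            # Any failure (including the assertion) leads to another attempt.
            time.sleep(delay_sec)
    raise Exception("Can't infer")


calls = {"n": 0}


def flaky_ask(prompt):
    # Fails on the first two calls, then answers with the upper-cased prompt.
    calls["n"] += 1
    if calls["n"] < 3:
        raise RuntimeError("temporary failure")
    return prompt.upper()


result = ask_with_retries(flaky_ask, ["hello"], attempts=3)
print(result, calls["n"])
```

Returning a list in both modes is what lets callers such as `chat_with_lm` read `response_batch[0]` regardless of whether the model supports batching.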
@@ -33,14 +33,33 @@ class CmdArgsService:
         yield __release()

     @staticmethod
-    def
+    def __find_suffix_ind(lst, idx_from, end_prefix):
+        for i in range(idx_from, len(lst)):
+            if lst[i].startswith(end_prefix):
+                return i
+        return len(lst)
+
+    @staticmethod
+    def extract_native_args(lst, end_prefix):
+        return lst[:CmdArgsService.__find_suffix_ind(lst, idx_from=0, end_prefix=end_prefix)]
+
+    @staticmethod
+    def find_grouped_args(lst, starts_with, end_prefix):
         """Slices a list in two, cutting on index matching "sep"
         """
-
-
-
-
-        return
+
+        # Checking the presence of starts_with.
+        # We have to return empty content in the case of absence starts_with in the lst.
+        if starts_with not in lst:
+            return []
+
+        # Assigning start index.
+        idx_from = lst.index(starts_with) + 1
+
+        # Assigning end index.
+        idx_to = CmdArgsService.__find_suffix_ind(lst, idx_from=idx_from, end_prefix=end_prefix)
+
+        return lst[idx_from:idx_to]

     @staticmethod
     def args_to_dict(args):
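To make the `%%`-prefixed argument groups concrete, here is a minimal standalone sketch of the slicing logic from the hunk above, written as module-level functions so it runs without the package. The sample `argv` list (including the file name `data.csv` and the token value) is hypothetical.

```python
def find_suffix_ind(lst, idx_from, end_prefix):
    # Index of the first item at or after idx_from that starts with end_prefix,
    # or the list length when no such item exists.
    for i in range(idx_from, len(lst)):
        if lst[i].startswith(end_prefix):
            return i
    return len(lst)


def extract_native_args(lst, end_prefix):
    # Everything before the first group marker belongs to the main parser.
    return lst[:find_suffix_ind(lst, idx_from=0, end_prefix=end_prefix)]


def find_grouped_args(lst, starts_with, end_prefix):
    # A missing group marker yields an empty argument list.
    if starts_with not in lst:
        return []
    idx_from = lst.index(starts_with) + 1
    idx_to = find_suffix_ind(lst, idx_from=idx_from, end_prefix=end_prefix)
    return lst[idx_from:idx_to]


argv = ["infer.py", "--src", "data.csv",
        "%%csv", "--delimiter", ";",
        "%%m", "--api_token", "TOKEN", "--temp", "0.1"]

native_args = extract_native_args(argv, end_prefix="%%")
csv_args = find_grouped_args(argv, starts_with="%%csv", end_prefix="%%")
model_args = find_grouped_args(argv, starts_with="%%m", end_prefix="%%")
missing_args = find_grouped_args(argv, starts_with="%%x", end_prefix="%%")

print(native_args)
print(csv_args)
print(model_args)
print(missing_args)
```

Each group ends where the next `%%` marker begins, so the order of the groups on the command line does not matter.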
@@ -0,0 +1,51 @@
+class BatchService(object):
+
+    @staticmethod
+    def handle_param_as_batch(batch, src_param, tgt_param, handle_func):
+        assert (isinstance(batch, list))
+        assert (isinstance(src_param, str))
+        assert (callable(handle_func))
+
+        _batch = [item[src_param] for item in batch]
+
+        # Do handling for the batch.
+        _handled_batch = handle_func(_batch)
+        assert (isinstance(_handled_batch, list))
+
+        # Apply changes.
+        for i, item in enumerate(batch):
+            item[tgt_param] = _handled_batch[i]
+
+
+class BatchIterator:
+
+    def __init__(self, data_iter, batch_size, end_value=None):
+        assert(isinstance(batch_size, int) and batch_size > 0)
+        assert(callable(end_value) or end_value is None)
+        self.__data_iter = data_iter
+        self.__index = 0
+        self.__batch_size = batch_size
+        self.__end_value = end_value
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        buffer = []
+        while True:
+            try:
+                data = next(self.__data_iter)
+            except StopIteration:
+                break
+            buffer.append(data)
+            if len(buffer) == self.__batch_size:
+                break
+
+        if len(buffer) > 0:
+            self.__index += 1
+            return buffer
+
+        if self.__end_value is None:
+            raise StopIteration
+        else:
+            return self.__end_value()
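The fixed-size grouping performed by `BatchIterator` in the hunk above can be sketched with a generator. This is an equivalent illustration under the assumption that no `end_value` callback is used, not the package's own code; the name `iter_batches` is hypothetical.

```python
from itertools import islice


def iter_batches(data_iter, batch_size):
    """Hypothetical generator-based sketch of fixed-size batching."""
    assert isinstance(batch_size, int) and batch_size > 0
    it = iter(data_iter)
    while True:
        # Take up to batch_size items; the last batch may be shorter.
        buffer = list(islice(it, batch_size))
        if not buffer:
            return
        yield buffer


batches = list(iter_batches(range(5), batch_size=2))
print(batches)
```

As in the class above, a trailing incomplete batch is still emitted rather than dropped.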
@@ -20,3 +20,7 @@ class DataService(object):
         field_names = list(parse_fields_func(prompt))
         return DataService.compose_prompt_text(
             prompt=prompt, data_dict=data_dict, field_names=field_names)
+
+    @staticmethod
+    def limit_prompts(prompts_list, limit=None):
+        return [p[:limit] if limit is not None else p for p in prompts_list]
@@ -0,0 +1,10 @@
+class DictionaryService:
+
+    @staticmethod
+    def custom_update(src_dict, other_dict):
+        for k, v in other_dict.items():
+            if k in src_dict:
+                raise Exception(f"The key `{k}` is already defined in both dicts with values: "
+                                f"`{src_dict[k]}` (src) and `{v}` (other)")
+            src_dict[k] = v
+        return src_dict
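The merge rule in the hunk above differs from `dict.update` in that a key present in both dictionaries is treated as an error. A minimal standalone sketch, assuming hypothetical row contents and placeholder keys (the real contents of `schema.cot_args` are not shown in this diff):

```python
def custom_update(src_dict, other_dict):
    """Standalone copy of the merge rule: no key may be defined twice."""
    for k, v in other_dict.items():
        if k in src_dict:
            raise Exception(f"The key `{k}` is already defined in both dicts")
        src_dict[k] = v
    return src_dict


row = {"uid": 0, "text": "I like it"}
# Hypothetical placeholder columns to be filled in later.
merged = custom_update(row, {"prompt": None, "response": None})

try:
    custom_update({"text": "a"}, {"text": "b"})
    conflict_raised = False
except Exception:
    conflict_raised = True

print(merged)
print(conflict_raised)
```

Note that the source dictionary is modified in place and returned, so `merged` and `row` refer to the same object.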
@@ -74,9 +74,9 @@ def chat_with_lm(lm, chain=None, model_name=None):
             logger.info(nice_output(actual_prompt, pad=pad*2, remove_new_line=True, width=80))

             # Response.
-
+            response_batch = lm.ask_core(batch=[actual_prompt])
             logger.info(pad_str(f"{model_name} (resp)->", pad=pad))
-            logger.info(nice_output(
+            logger.info(nice_output(response_batch[0], pad=pad * 2, remove_new_line=False, width=80))

             # Collecting the answer for the next turn.
-            data_dict[prompt_args["out"]] =
+            data_dict[prompt_args["out"]] = response_batch[0]
@@ -2,12 +2,11 @@ class SchemaService(object):

     def __init__(self, json_data):
         self.src = json_data
-        self.name = self.src["name"]
         self.r2p, self.p2r, self.cot_args, self.chain = SchemaService.__init_schema(prompts=json_data["schema"])

     @classmethod
     def from_prompt(cls, prompt):
-        prompt_schema = {"
+        prompt_schema = {"schema": [{"prompt": prompt, "out": "response", "in": "prompt"}]}
         return cls(prompt_schema)

     @staticmethod
@@ -0,0 +1,191 @@
+from os.path import join, basename
+
+import argparse
+import logging
+import sys
+
+from source_iter.service_csv import CsvService
+from source_iter.service_jsonl import JsonlService
+from source_iter.service_sqlite import SQLite3Service
+from tqdm import tqdm
+
+from bulk_chain.api import INFER_MODES, _infer_batch, CWD
+from bulk_chain.core.llm_base import BaseLM
+from bulk_chain.core.service_args import CmdArgsService
+from bulk_chain.core.service_dict import DictionaryService
+from bulk_chain.core.service_json import JsonService
+from bulk_chain.core.service_llm import chat_with_lm
+from bulk_chain.core.service_schema import SchemaService
+from bulk_chain.core.utils import dynamic_init, find_by_prefix, handle_table_name, optional_limit_iter, parse_filepath
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
+
+
+WRITER_PROVIDERS = {
+    "sqlite": lambda filepath, table_name, data_it, infer_data_func, **kwargs: SQLite3Service.write(
+        data_it=data_it, target=filepath, table_name=table_name, data2col_func=infer_data_func,
+        skip_existed=True, **kwargs)
+}
+
+
+READER_PROVIDERS = {
+    "sqlite": lambda filepath, table_name: SQLite3Service.read(filepath, table=table_name)
+}
+
+
+def init_llm(**model_kwargs):
+    """ This method perform dynamic initialization of LLM from third-party resource.
+    """
+
+    # List of the Supported models and their API wrappers.
+    models_preset = {
+        "dynamic": lambda: dynamic_init(class_dir=CWD, class_filepath=llm_model_name,
+                                        class_name=llm_model_params)(**model_kwargs)
+    }
+
+    # Initialize LLM model.
+    params = args.adapter.split(':')
+    llm_model_type = params[0]
+    llm_model_name = params[1] if len(params) > 1 else params[-1]
+    llm_model_params = ':'.join(params[2:]) if len(params) > 2 else None
+    llm = find_by_prefix(d=models_preset, key=llm_model_type)()
+
+    return llm, llm_model_name
+
+
+def iter_content_cached(input_dicts_it, llm, schema, cache_target, limit_prompt=None, **cache_kwargs):
+    assert (isinstance(llm, BaseLM))
+    assert (isinstance(cache_target, str))
+
+    # Quick initialization of the schema.
+    if isinstance(schema, str):
+        schema = JsonService.read(schema)
+    if isinstance(schema, dict):
+        schema = SchemaService(json_data=schema)
+
+    # Iterator of the queries.
+    prompts_it = map(
+        lambda data: DictionaryService.custom_update(src_dict=data, other_dict=schema.cot_args),
+        input_dicts_it
+    )
+
+    # Parse target.
+    cache_filepath, _, cache_table = parse_filepath(filepath=cache_target)
+
+    # Perform caching first.
+    WRITER_PROVIDERS["sqlite"](
+        filepath=cache_filepath, table_name=cache_table,
+        data_it=tqdm(prompts_it, desc="Iter content"),
+        infer_data_func=lambda c, prompt: _infer_batch(
+            batch=[prompt], cols=[c],
+            infer_func=lambda batch: INFER_MODES["default"](llm, batch, limit_prompt),
+            schema=schema)[0][c],
+        **cache_kwargs)
+
+    # Then retrieve data.
+    return READER_PROVIDERS["sqlite"](filepath=cache_filepath, table_name=cache_table)
+
+
+if __name__ == '__main__':
+
+    parser = argparse.ArgumentParser(description="Infer Instruct LLM inference based on CoT schema")
+    parser.add_argument('--adapter', dest='adapter', type=str, default=None)
+    parser.add_argument('--attempts', dest='attempts', type=int, default=None)
+    parser.add_argument('--id-col', dest='id_col', type=str, default="uid")
+    parser.add_argument('--src', dest='src', type=str, default=None)
+    parser.add_argument('--schema', dest='schema', type=str, default=None,
+                        help="Path to the JSON file that describes schema")
+    parser.add_argument('--to', dest='to', type=str, default=None, choices=["csv", "sqlite"])
+    parser.add_argument('--output', dest='output', type=str, default=None)
+    parser.add_argument('--limit', dest='limit', type=int, default=None,
+                        help="Limit amount of source texts for prompting.")
+    parser.add_argument('--limit-prompt', dest="limit_prompt", type=int, default=None,
+                        help="Optional trimming prompt by the specified amount of characters.")
+
+    # Extract native arguments.
+    native_args = CmdArgsService.extract_native_args(sys.argv, end_prefix="%%")
+    args = parser.parse_args(args=native_args[1:])
+
+    # Extract csv-related arguments.
+    csv_args = CmdArgsService.find_grouped_args(lst=sys.argv, starts_with="%%csv", end_prefix="%%")
+    csv_args_dict = CmdArgsService.args_to_dict(csv_args)
+
+    # Extract model-related arguments and Initialize Large Language Model.
+    model_args = CmdArgsService.find_grouped_args(lst=sys.argv, starts_with="%%m", end_prefix="%%")
+    model_args_dict = CmdArgsService.args_to_dict(model_args) | {"attempts": args.attempts}
+    llm, llm_model_name = init_llm(**model_args_dict)
+
+    # Setup schema.
+    schema = SchemaService(json_data=JsonService.read(args.schema))
+    schema_name = schema.src.get("name", None)
+    if schema is not None:
+        logger.info(f"Using schema: {schema_name}")
+
+    input_providers = {
+        None: lambda _: chat_with_lm(llm, chain=schema.chain, model_name=llm_model_name),
+        "csv": lambda filepath: CsvService.read(src=filepath, row_id_key=args.id_col,
+                                                as_dict=True, skip_header=True,
+                                                delimiter=csv_args_dict.get("delimiter", ","),
+                                                escapechar=csv_args_dict.get("escapechar", None)),
+        "tsv": lambda filepath: CsvService.read(src=filepath, row_id_key=args.id_col,
+                                                as_dict=True, skip_header=True,
+                                                delimiter=csv_args_dict.get("delimiter", "\t"),
+                                                escapechar=csv_args_dict.get("escapechar", None)),
+        "jsonl": lambda filepath: JsonlService.read(src=filepath, row_id_key=args.id_col)
+    }
+
+    output_providers = {
+        "csv": lambda filepath, data_it, header: CsvService.write(target=filepath,
+                                                                  data_it=data_it, header=header,
+                                                                  delimiter=csv_args_dict.get("delimiter", ","),
+                                                                  escapechar=csv_args_dict.get("escapechar", None),
+                                                                  it_type=None),
+        "tsv": lambda filepath, data_it, header: CsvService.write(target=filepath,
+                                                                  data_it=data_it, header=header,
+                                                                  delimiter=csv_args_dict.get("delimiter", "\t"),
+                                                                  escapechar=csv_args_dict.get("escapechar", None),
+                                                                  it_type=None),
+        "jsonl": lambda filepath, data_it, header:
+            JsonlService.write(target=filepath,
+                               data_it=map(lambda item: {key: item[i] for i, key in enumerate(header)}, data_it))
+    }
+
+    # Setup output.
+    args.output = args.output.format(model=llm.name()) if args.output is not None else args.output
+    tgt_filepath, tgt_ext, tgt_meta = parse_filepath(args.output, default_ext=args.to)
+
+    # Input extension type defines the provider.
+    src_filepath, src_ext, src_meta = parse_filepath(args.src)
+
+    # Check whether we are in chat mode.
+    if src_ext is None:
+        input_providers[src_ext](None)
+        exit(0)
+
+    def default_output_file_template(ext):
+        # This is a default template for output files to be generated.
+        return "".join(["_".join([join(CWD, basename(src_filepath)), llm.name(), schema_name]), ext])
+
+    # Setup cache target as well as the related table.
+    cache_filepath = default_output_file_template(".sqlite") if tgt_filepath is None else tgt_filepath
+    cache_table = handle_table_name(tgt_meta if tgt_meta is not None else "contents")
+
+    # This is a content that we extracted via input provider.
+    it_data = input_providers[src_ext](src_filepath)
+
+    data_it = iter_content_cached(input_dicts_it=optional_limit_iter(it_data=it_data, limit=args.limit),
+                                  limit_prompt=args.limit_prompt,
+                                  schema=schema,
+                                  llm=llm,
+                                  id_column_name=args.id_col,
+                                  cache_target=":".join([cache_filepath, cache_table]))
+
+    # Setup output target
+    tgt_ext = src_ext if tgt_ext is None else tgt_ext
+    output_target = default_output_file_template(f".{tgt_ext}") if tgt_filepath is None else tgt_filepath
+
+    # Perform output writing process.
+    output_providers[tgt_ext](filepath=output_target,
+                              data_it=data_it,
+                              header=SQLite3Service.read_columns(target=cache_filepath, table=cache_table))
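The `--adapter` value is split on `:` inside `init_llm` in the hunk above. The sketch below repeats only that splitting step so the three resulting parts can be inspected; `parse_adapter` is a hypothetical helper name, and the dynamic class loading itself is left out.

```python
def parse_adapter(adapter):
    """Hypothetical helper repeating the string splitting from `init_llm`."""
    params = adapter.split(':')
    model_type = params[0]
    # With a single component, the type doubles as the name.
    model_name = params[1] if len(params) > 1 else params[-1]
    # Everything after the second component is joined back together.
    model_params = ':'.join(params[2:]) if len(params) > 2 else None
    return model_type, model_name, model_params


parsed = parse_adapter("dynamic:ext/replicate.py:Replicate")
short = parse_adapter("dynamic")

print(parsed)
print(short)
```

For the README's example value, the second part is the path passed as `class_filepath` and the third is the name passed as `class_name`.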
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: bulk_chain
-Version: 0.
+Version: 0.25.0
 Summary: A lightweight, no-strings-attached Chain-of-Thought framework for your LLM, ensuring reliable results for bulk input requests.
 Home-page: https://github.com/nicolay-r/bulk-chain
 Author: Nicolay Rusnachenko
@@ -15,33 +15,42 @@ Classifier: Topic :: Text Processing :: Linguistic
 Requires-Python: >=3.6
 Description-Content-Type: text/markdown
 License-File: LICENSE
-Requires-Dist: tqdm
-Requires-Dist: source-iter==0.24.2

-# bulk-chain 0.
+# bulk-chain 0.25.0

 [](https://colab.research.google.com/github/nicolay-r/bulk-chain/blob/master/bulk_chain_tutorial.ipynb)
 [](https://x.com/nicolayr_/status/1847969224636961033)
+[](https://pypistats.org/packages/bulk-chain)

 <p align="center">
 <img src="logo.png"/>
 </p>

-A lightweight, no-strings-attached **[Chain-of-Thought](https://arxiv.org/abs/2201.11903)
-It allows applying series of prompts formed into `schema` (See [related section](#chain-of-thought-schema))
+A lightweight, no-strings-attached **framework** for your LLM that allows applying [Chain-of-Thought](https://arxiv.org/abs/2201.11903) prompt `schema` (See [related section](#chain-of-thought-schema)) towards a massive textual collections.

-### Features
+### Main Features
 * ✅ **No-strings**: you're free to LLM dependencies and flexible `venv` customization.
-* ✅ **Provides iterator over infinite amount of input contexts** served in `CSV`/`JSONL`.
-* ✅ **Progress caching**: withstanding exception during LLM calls by using `sqlite3` engine for caching LLM answers;
 * ✅ **Support schemas descriptions** for Chain-of-Thought concept.
+* ✅ **Provides iterator over infinite amount of input contexts** served in `CSV`/`JSONL`.
+
+### Extra Features
+* ✅ **Progress caching [for remote LLMs]**: withstanding exception during LLM calls by using `sqlite3` engine for caching LLM answers;
+

 # Installation

+From PyPI:
+
 ```bash
 pip install bulk-chain
 ```

+or latest version from here:
+
+```bash
+pip install git+https://github.com/nicolay-r/bulk-chain@master
+```
+
 ## Chain-of-Thought Schema

 To declare Chain-of-Though (CoT) schema, this project exploits `JSON` format.
@@ -64,35 +73,37 @@ Below, is an example on how to declare your own schema:
|
|
|
64
73
|
}
|
|
65
74
|
```
|
|
66
75
|
|
|
67
|
-
Another templates are available [here](/ext/schema/
|
|
76
|
+
Another templates are available [here](/ext/schema/).
|
|
68
77
|
|
|
69
78
|
# Usage
|
|
70
79
|
|
|
71
|
-
|
|
80
|
+
Preliminary steps:
|
|
72
81
|
|
|
73
|
-
1. Define your [
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
82
|
+
1. Define your [schema](#chain-of-thought-schema) ([Example for Sentiment Analysis](/ext/schema/thor_cot_schema.json)))
|
|
83
|
+
2. Wrap or pick **LLM model** from the [list of presets](/ext/).
|
|
84
|
+
|
|
85
|
+
## API
|
|
86
|
+
|
|
87
|
+
Please take a look at the [**related Wiki page**](https://github.com/nicolay-r/bulk-chain/wiki)
|
|
88
|
+
|
|
89
|
+
## Shell
|
|
90
|
+
|
|
91
|
+
> **NOTE:** You have to install `source-iter` package
|
|
81
92
|
|
|
82
|
-
3. Launch inference in (chat mode):
|
|
83
93
|
```bash
|
|
84
|
-
|
|
85
|
-
--
|
|
86
|
-
--
|
|
87
|
-
|
|
88
|
-
|
|
94
|
+
python3 -m bulk_chain.infer \
|
|
95
|
+
--src "<PATH-TO-YOUR-CSV-or-JSONL>" \
|
|
96
|
+
--schema "ext/schema/default.json" \
|
|
97
|
+
--adapter "dynamic:ext/replicate.py:Replicate" \
|
|
98
|
+
%%m \
|
|
99
|
+
--api_token "<REPLICATE-API-TOKEN>" \
|
|
89
100
|
--temp 0.1
|
|
90
101
|
```
|
|
91
102
|
|
|
92
103
|
# Embed your LLM
|
|
93
104
|
|
|
94
105
|
All you have to do is to implement `BaseLM` class, that includes:
|
|
95
|
-
* `__init__` -- for
|
|
106
|
+
* `__init__` -- for setting up *batching mode support* and (optional) *model name*;
|
|
96
107
|
* `ask(prompt)` -- infer your model with the given `prompt`.
|
|
97
108
|
|
|
98
109
|
See examples with models [here](/ext).
|
|
@@ -2,19 +2,22 @@ LICENSE
|
|
|
2
2
|
README.md
|
|
3
3
|
setup.py
|
|
4
4
|
bulk_chain/__init__.py
|
|
5
|
+
bulk_chain/api.py
|
|
5
6
|
bulk_chain/infer.py
|
|
6
7
|
bulk_chain.egg-info/PKG-INFO
|
|
7
8
|
bulk_chain.egg-info/SOURCES.txt
|
|
8
9
|
bulk_chain.egg-info/dependency_links.txt
|
|
9
|
-
bulk_chain.egg-info/requires.txt
|
|
10
10
|
bulk_chain.egg-info/top_level.txt
|
|
11
11
|
bulk_chain/core/__init__.py
|
|
12
12
|
bulk_chain/core/llm_base.py
|
|
13
13
|
bulk_chain/core/service_args.py
|
|
14
|
+
bulk_chain/core/service_batch.py
|
|
14
15
|
bulk_chain/core/service_data.py
|
|
16
|
+
bulk_chain/core/service_dict.py
|
|
15
17
|
bulk_chain/core/service_json.py
|
|
16
18
|
bulk_chain/core/service_llm.py
|
|
17
19
|
bulk_chain/core/service_schema.py
|
|
18
20
|
bulk_chain/core/utils.py
|
|
21
|
+
test/test_api.py
|
|
19
22
|
test/test_args_seeking.py
|
|
20
23
|
test/test_cmdargs.py
|
|
@@ -15,7 +15,7 @@ def get_requirements(filenames):
|
|
|
15
15
|
|
|
16
16
|
setup(
|
|
17
17
|
name='bulk_chain',
|
|
18
|
-
version='0.
|
|
18
|
+
version='0.25.0',
|
|
19
19
|
python_requires=">=3.6",
|
|
20
20
|
description='A lightweight, no-strings-attached Chain-of-Thought framework for your LLM, '
|
|
21
21
|
'ensuring reliable results for bulk input requests.',
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
import unittest
|
|
2
|
+
from os.path import join
|
|
3
|
+
|
|
4
|
+
from bulk_chain.api import iter_content, CWD
|
|
5
|
+
from bulk_chain.core.utils import dynamic_init
|
|
6
|
+
from bulk_chain.infer import iter_content_cached
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class TestAPI(unittest.TestCase):
|
|
10
|
+
|
|
11
|
+
llm = dynamic_init(class_dir=join(CWD, ".."),
|
|
12
|
+
class_filepath="ext/replicate.py",
|
|
13
|
+
class_name="Replicate")(api_token="<API-KEY>")
|
|
14
|
+
|
|
15
|
+
def it_data(self, n):
|
|
16
|
+
for i in range(n):
|
|
17
|
+
yield {"ind": i, "text": "X invent sanctions against Y", "entity": "X"}
|
|
18
|
+
|
|
19
|
+
def test_iter_cached(self):
|
|
20
|
+
data_it = iter_content_cached(input_dicts_it=self.it_data(20),
|
|
21
|
+
llm=self.llm,
|
|
22
|
+
schema="../ext/schema/default.json",
|
|
23
|
+
# Cache-related extra parameters.
|
|
24
|
+
cache_target="out.sqlite:content",
|
|
25
|
+
id_column_name="ind")
|
|
26
|
+
|
|
27
|
+
for data in data_it:
|
|
28
|
+
print(data)
|
|
29
|
+
|
|
30
|
+
def test_iter(self):
|
|
31
|
+
data_it = iter_content(input_dicts_it=self.it_data(20),
|
|
32
|
+
llm=self.llm,
|
|
33
|
+
batch_size=1,
|
|
34
|
+
return_batch=True,
|
|
35
|
+
schema="../ext/schema/default.json")
|
|
36
|
+
|
|
37
|
+
for data in data_it:
|
|
38
|
+
print(data)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
if __name__ == '__main__':
|
|
42
|
+
unittest.main()
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
import sys
|
|
2
|
+
|
|
3
|
+
from bulk_chain.core.service_args import CmdArgsService
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
# Csv-related.
|
|
7
|
+
csv_args = CmdArgsService.find_grouped_args(sys.argv, starts_with="%%csv", end_prefix="%%")
|
|
8
|
+
print(csv_args)
|
|
9
|
+
csv_args = CmdArgsService.args_to_dict(csv_args)
|
|
10
|
+
print("csv\t", csv_args)
|
|
11
|
+
|
|
12
|
+
# Model-related.
|
|
13
|
+
m_args = CmdArgsService.find_grouped_args(sys.argv, starts_with="%%m", end_prefix="%%")
|
|
14
|
+
m_args = CmdArgsService.args_to_dict(m_args)
|
|
15
|
+
print("mod\t", m_args)
|
|
16
|
+
|
|
17
|
+
# native.
|
|
18
|
+
n_args = CmdArgsService.extract_native_args(sys.argv, end_prefix="%%")
|
|
19
|
+
n_args = CmdArgsService.args_to_dict(n_args)
|
|
20
|
+
print("nat\t", n_args)
|
|
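The test script above runs `CmdArgsService` on `sys.argv`, where `%%`-prefixed tokens such as `%%csv` and `%%m` open argument groups. The package's implementation is not shown in this part of the diff, so the following is only a sketch of one plausible grouping behaviour. `split_grouped_args` and `pairs_to_dict` are hypothetical helpers, not bulk-chain functions, and the sample `argv` is made up for illustration.

```python
def split_grouped_args(argv, starts_with, end_prefix="%%"):
    """Collect tokens after `starts_with` up to the next token that begins with `end_prefix`."""
    group = []
    inside = False
    for token in argv:
        if token == starts_with:
            inside = True
            continue
        if inside and token.startswith(end_prefix):
            break
        if inside:
            group.append(token)
    return group


def pairs_to_dict(tokens):
    """Turn ['--key', 'value', ...] into {'key': 'value'}; a flag without a value maps to True."""
    result = {}
    key = None
    for token in tokens:
        if token.startswith("--"):
            if key is not None:
                result[key] = True
            key = token[2:]
        elif key is not None:
            result[key] = token
            key = None
    if key is not None:
        result[key] = True
    return result


# Hypothetical command line: native args first, then a csv group and a model group.
argv = ["infer.py", "--src", "data.csv", "%%csv", "--delimiter", ",", "%%m", "--temp", "0.1"]
csv_args = pairs_to_dict(split_grouped_args(argv, starts_with="%%csv"))
model_args = pairs_to_dict(split_grouped_args(argv, starts_with="%%m"))
print(csv_args)
print(model_args)
```

In this sketch values stay strings; any type conversion would be left to the consumer of the dictionary.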
@@ -1,35 +0,0 @@
|
|
|
1
|
-
import logging
|
|
2
|
-
import time
|
|
3
|
-
|
|
4
|
-
from bulk_chain.core.utils import format_model_name
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
class BaseLM(object):
|
|
8
|
-
|
|
9
|
-
def __init__(self, name, attempts=None, delay_sec=1, enable_log=True, **kwargs):
|
|
10
|
-
self.__name = name
|
|
11
|
-
self.__attempts = 1 if attempts is None else attempts
|
|
12
|
-
self.__delay_sec = delay_sec
|
|
13
|
-
|
|
14
|
-
if enable_log:
|
|
15
|
-
self.__logger = logging.getLogger(__name__)
|
|
16
|
-
logging.basicConfig(level=logging.INFO)
|
|
17
|
-
|
|
18
|
-
def ask_safe(self, prompt):
|
|
19
|
-
|
|
20
|
-
for i in range(self.__attempts):
|
|
21
|
-
try:
|
|
22
|
-
response = self.ask(prompt)
|
|
23
|
-
return response
|
|
24
|
-
except:
|
|
25
|
-
if self.__logger is not None:
|
|
26
|
-
self.__logger.info("Unable to infer the result. Try {} out of {}.".format(i, self.__attempts))
|
|
27
|
-
time.sleep(self.__delay_sec)
|
|
28
|
-
|
|
29
|
-
raise Exception("Can't infer")
|
|
30
|
-
|
|
31
|
-
def ask(self, prompt):
|
|
32
|
-
raise NotImplemented()
|
|
33
|
-
|
|
34
|
-
def name(self):
|
|
35
|
-
return format_model_name(self.__name)
|
|
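The removed `BaseLM.ask_safe` retries `ask` up to `attempts` times with a fixed delay between tries. A minimal self-contained sketch of that retry pattern follows. `RetryingLM` and `FlakyLM` are hypothetical names, and the sketch deliberately differs from the removed code: it raises `NotImplementedError`, always defines the logger attribute, and catches `Exception` rather than using a bare `except`. It is an illustration, not the package's implementation.

```python
import logging
import time


class RetryingLM:
    """Hypothetical base class illustrating the retry-with-delay pattern."""

    def __init__(self, name, attempts=None, delay_sec=1, enable_log=True):
        self._name = name
        self._attempts = 1 if attempts is None else attempts
        self._delay_sec = delay_sec
        # Always define the attribute so ask_safe can test it safely.
        self._logger = logging.getLogger(__name__) if enable_log else None

    def ask(self, prompt):
        raise NotImplementedError()

    def ask_safe(self, prompt):
        for i in range(self._attempts):
            try:
                return self.ask(prompt)
            except Exception:
                if self._logger is not None:
                    self._logger.info("Unable to infer the result. Try %d out of %d.",
                                      i + 1, self._attempts)
                time.sleep(self._delay_sec)
        raise RuntimeError("Can't infer")


class FlakyLM(RetryingLM):
    """Hypothetical model that fails a fixed number of times before answering."""

    def __init__(self, fail_times, **kwargs):
        super().__init__(name="flaky", **kwargs)
        self.calls = 0
        self._fail_times = fail_times

    def ask(self, prompt):
        self.calls += 1
        if self.calls <= self._fail_times:
            raise ConnectionError("temporary failure")
        return "answer to: " + prompt


llm = FlakyLM(fail_times=2, attempts=3, delay_sec=0)
print(llm.ask_safe("hello"))
```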
@@ -1,176 +0,0 @@
|
|
|
1
|
-
import os
|
|
2
|
-
from os.path import join, basename
|
|
3
|
-
|
|
4
|
-
import argparse
|
|
5
|
-
import logging
|
|
6
|
-
import sys
|
|
7
|
-
|
|
8
|
-
from tqdm import tqdm
|
|
9
|
-
|
|
10
|
-
from source_iter.service_csv import CsvService
|
|
11
|
-
from source_iter.service_jsonl import JsonlService
|
|
12
|
-
from source_iter.service_sqlite import SQLite3Service
|
|
13
|
-
|
|
14
|
-
from bulk_chain.core.llm_base import BaseLM
|
|
15
|
-
from bulk_chain.core.service_args import CmdArgsService
|
|
16
|
-
from bulk_chain.core.service_data import DataService
|
|
17
|
-
from bulk_chain.core.service_json import JsonService
|
|
18
|
-
from bulk_chain.core.service_llm import chat_with_lm
|
|
19
|
-
from bulk_chain.core.service_schema import SchemaService
|
|
20
|
-
from bulk_chain.core.utils import dynamic_init, find_by_prefix, handle_table_name, optional_limit_iter, parse_filepath
|
|
21
|
-
|
|
22
|
-
logger = logging.getLogger(__name__)
|
|
23
|
-
logging.basicConfig(level=logging.INFO)
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
CWD = os.getcwd()
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
def init_llm(**model_kwargs):
|
|
30
|
-
""" This method perform dynamic initialization of LLM from third-party resource.
|
|
31
|
-
"""
|
|
32
|
-
|
|
33
|
-
# List of the Supported models and their API wrappers.
|
|
34
|
-
models_preset = {
|
|
35
|
-
"dynamic": lambda: dynamic_init(class_dir=CWD, class_filepath=llm_model_name,
|
|
36
|
-
class_name=llm_model_params)(**model_kwargs)
|
|
37
|
-
}
|
|
38
|
-
|
|
39
|
-
# Initialize LLM model.
|
|
40
|
-
params = args.adapter.split(':')
|
|
41
|
-
llm_model_type = params[0]
|
|
42
|
-
llm_model_name = params[1] if len(params) > 1 else params[-1]
|
|
43
|
-
llm_model_params = ':'.join(params[2:]) if len(params) > 2 else None
|
|
44
|
-
llm = find_by_prefix(d=models_preset, key=llm_model_type)()
|
|
45
|
-
|
|
46
|
-
return llm, llm_model_name
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
def init_schema(json_filepath):
|
|
50
|
-
return SchemaService(json_data=JsonService.read(json_filepath))
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
def iter_content(input_dicts_iter, llm, schema, cache_target, cache_table, id_column_name):
|
|
54
|
-
""" This method represent Python API aimed at application of `llm` towards
|
|
55
|
-
iterator of input_dicts via cache_target that refers to the SQLite using
|
|
56
|
-
the given `schema`
|
|
57
|
-
"""
|
|
58
|
-
assert (isinstance(llm, BaseLM))
|
|
59
|
-
assert (isinstance(schema, SchemaService))
|
|
60
|
-
assert (isinstance(cache_target, str))
|
|
61
|
-
assert (isinstance(cache_table, str))
|
|
62
|
-
|
|
63
|
-
infer_modes = {
|
|
64
|
-
"default": lambda prompt: llm.ask_safe(prompt[:args.limit_prompt] if args.limit_prompt is not None else prompt)
|
|
65
|
-
}
|
|
66
|
-
|
|
67
|
-
def optional_update_data_records(c, data):
|
|
68
|
-
assert (isinstance(c, str))
|
|
69
|
-
|
|
70
|
-
if c in schema.p2r:
|
|
71
|
-
data[c] = DataService.get_prompt_text(prompt=data[c]["prompt"], data_dict=data)
|
|
72
|
-
if c in schema.r2p:
|
|
73
|
-
p_column = schema.r2p[c]
|
|
74
|
-
# This instruction takes a lot of time in a non-batching mode.
|
|
75
|
-
data[c] = infer_modes["default"](data[p_column])
|
|
76
|
-
|
|
77
|
-
return data[c]
|
|
78
|
-
|
|
79
|
-
cache_providers = {
|
|
80
|
-
"sqlite": lambda filepath, table_name, data_it: SQLite3Service.write_missed(
|
|
81
|
-
data_it=data_it, target=filepath,
|
|
82
|
-
data2col_func=optional_update_data_records,
|
|
83
|
-
table_name=handle_table_name(table_name if table_name is not None else "contents"),
|
|
84
|
-
id_column_name=id_column_name)
|
|
85
|
-
}
|
|
86
|
-
|
|
87
|
-
# We optionally wrap into limiter.
|
|
88
|
-
queries_it = optional_limit_iter(
|
|
89
|
-
it_data=map(lambda data: data.update(schema.cot_args) or data, input_dicts_iter),
|
|
90
|
-
limit=args.limit)
|
|
91
|
-
|
|
92
|
-
# Provide data caching.
|
|
93
|
-
cache_providers["sqlite"](cache_target, table_name=tgt_meta, data_it=tqdm(queries_it, desc="Iter content"))
|
|
94
|
-
|
|
95
|
-
return SQLite3Service.read(cache_target, table=cache_table)
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
if __name__ == '__main__':
|
|
99
|
-
|
|
100
|
-
parser = argparse.ArgumentParser(description="Infer Instruct LLM inference based on CoT schema")
|
|
101
|
-
parser.add_argument('--adapter', dest='adapter', type=str, default=None)
|
|
102
|
-
parser.add_argument('--attempts', dest='attempts', type=int, default=None)
|
|
103
|
-
parser.add_argument('--id-col', dest='id_col', type=str, default="uid")
|
|
104
|
-
parser.add_argument('--src', dest='src', type=str, default=None)
|
|
105
|
-
parser.add_argument('--schema', dest='schema', type=str, default=None,
|
|
106
|
-
help="Path to the JSON file that describes schema")
|
|
107
|
-
parser.add_argument('--to', dest='to', type=str, default=None, choices=["csv", "sqlite"])
|
|
108
|
-
parser.add_argument('--output', dest='output', type=str, default=None)
|
|
109
|
-
parser.add_argument('--limit', dest='limit', type=int, default=None,
|
|
110
|
-
help="Limit amount of source texts for prompting.")
|
|
111
|
-
parser.add_argument('--limit-prompt', dest="limit_prompt", type=int, default=None,
|
|
112
|
-
help="Optional trimming prompt by the specified amount of characters.")
|
|
113
|
-
|
|
114
|
-
native_args, model_args = CmdArgsService.partition_list(lst=sys.argv, sep="%%")
|
|
115
|
-
|
|
116
|
-
args = parser.parse_args(args=native_args[1:])
|
|
117
|
-
|
|
118
|
-
# Initialize Large Language Model.
|
|
119
|
-
model_args_dict = CmdArgsService.args_to_dict(model_args) | {"attempts": args.attempts}
|
|
120
|
-
llm, llm_model_name = init_llm(**model_args_dict)
|
|
121
|
-
|
|
122
|
-
# Setup schema.
|
|
123
|
-
schema = init_schema(args.schema)
|
|
124
|
-
if schema is not None:
|
|
125
|
-
logger.info(f"Using schema: {schema.name}")
|
|
126
|
-
|
|
127
|
-
input_providers = {
|
|
128
|
-
None: lambda _: chat_with_lm(llm, chain=schema.chain, model_name=llm_model_name),
|
|
129
|
-
"csv": lambda filepath: CsvService.read(src=filepath, row_id_key=args.id_col,
|
|
130
|
-
as_dict=True, skip_header=True,
|
|
131
|
-
delimiter=model_args_dict.get("delimiter", "\t"),
|
|
132
|
-
escapechar=model_args_dict.get("escapechar", None)),
|
|
133
|
-
"jsonl": lambda filepath: JsonlService.read(src=filepath, row_id_key=args.id_col)
|
|
134
|
-
}
|
|
135
|
-
|
|
136
|
-
output_providers = {
|
|
137
|
-
"csv": lambda filepath, data_it, header:
|
|
138
|
-
CsvService.write(target=filepath, data_it=data_it, header=header, it_type=None),
|
|
139
|
-
"jsonl": lambda filepath, data_it, header:
|
|
140
|
-
JsonlService.write(target=filepath,
|
|
141
|
-
data_it=map(lambda item: {key: item[i] for i, key in enumerate(header)}, data_it))
|
|
142
|
-
}
|
|
143
|
-
|
|
144
|
-
# Setup output.
|
|
145
|
-
args.output = args.output.format(model=llm.name()) if args.output is not None else args.output
|
|
146
|
-
tgt_filepath, tgt_ext, tgt_meta = parse_filepath(args.output, default_ext=args.to)
|
|
147
|
-
|
|
148
|
-
# Input extension type defines the provider.
|
|
149
|
-
src_filepath, src_ext, src_meta = parse_filepath(args.src)
|
|
150
|
-
|
|
151
|
-
# Check whether we are in chat mode.
|
|
152
|
-
if src_ext is None:
|
|
153
|
-
input_providers[src_ext](None)
|
|
154
|
-
exit(0)
|
|
155
|
-
|
|
156
|
-
# Setup cache target as well as the related table.
|
|
157
|
-
cache_target = "".join(["_".join([join(CWD, basename(src_filepath)), llm.name(), schema.name]), f".sqlite"]) \
|
|
158
|
-
if tgt_filepath is None else tgt_filepath
|
|
159
|
-
cache_table = handle_table_name(tgt_meta if tgt_meta is not None else "contents")
|
|
160
|
-
|
|
161
|
-
data_it = iter_content(input_dicts_iter=input_providers[src_ext](src_filepath),
|
|
162
|
-
schema=schema,
|
|
163
|
-
llm=llm,
|
|
164
|
-
id_column_name=args.id_col,
|
|
165
|
-
cache_target=cache_target,
|
|
166
|
-
cache_table=cache_table)
|
|
167
|
-
|
|
168
|
-
# Setup output target
|
|
169
|
-
tgt_ext = src_ext if tgt_ext is None else tgt_ext
|
|
170
|
-
output_target = "".join(["_".join([join(CWD, basename(src_filepath)), llm.name(), schema.name]), f".{tgt_ext}"]) \
|
|
171
|
-
if tgt_filepath is None else tgt_filepath
|
|
172
|
-
|
|
173
|
-
# Perform output writing process.
|
|
174
|
-
output_providers[tgt_ext](filepath=output_target,
|
|
175
|
-
data_it=data_it,
|
|
176
|
-
header=SQLite3Service.read_columns(target=cache_target, table=cache_table))
|
|
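The removed `init_llm` splits the `--adapter` value on `:` into a type, a file path, and a class name. The sketch below isolates just that parsing step. `parse_adapter` is a hypothetical helper name; the split logic mirrors the removed lines, and the adapter string is the one from the README's shell example.

```python
def parse_adapter(adapter):
    """Split an adapter string such as 'dynamic:<file>:<class>' into its three parts."""
    params = adapter.split(":")
    model_type = params[0]
    model_name = params[1] if len(params) > 1 else params[-1]
    model_params = ":".join(params[2:]) if len(params) > 2 else None
    return model_type, model_name, model_params


print(parse_adapter("dynamic:ext/replicate.py:Replicate"))
```

With a single-token value, the type and the name come out identical and the third part is `None`, which follows directly from the fallback to `params[-1]`.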
File without changes
|
|