bulk-chain 0.25.2__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bulk_chain/api.py CHANGED
@@ -1,17 +1,17 @@
+ import collections
  import os
  from itertools import chain

  from bulk_chain.core.llm_base import BaseLM
- from bulk_chain.core.service_batch import BatchIterator, BatchService
+ from bulk_chain.core.service_batch import BatchIterator
  from bulk_chain.core.service_data import DataService
  from bulk_chain.core.service_dict import DictionaryService
  from bulk_chain.core.service_json import JsonService
  from bulk_chain.core.service_schema import SchemaService
  from bulk_chain.core.utils import dynamic_init, find_by_prefix

+
  INFER_MODES = {
-     "default": lambda llm, prompt, limit_prompt=None: llm.ask_core(
-         prompt[:limit_prompt] if limit_prompt is not None else prompt),
      "batch": lambda llm, batch, limit_prompt=None: llm.ask_core(
          DataService.limit_prompts(batch, limit=limit_prompt))
  }
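With 1.0.0 the single-prompt `"default"` mode is gone: an adapter's `ask_core` is now always handed a list of prompts and is expected to return one response per item. A minimal sketch of the remaining `"batch"` entry in action; `EchoLM` is a hypothetical stand-in for illustration only, not part of the package, and the trimming behavior of `limit_prompts` is assumed to mirror the removed `prompt[:limit_prompt]` logic:

```python
from bulk_chain.api import INFER_MODES


class EchoLM:
    """Hypothetical stand-in for a BaseLM adapter: echoes every prompt back."""

    def ask_core(self, batch):
        # In 1.0.0 `batch` is always a list of prompt strings.
        return [f"echo: {prompt}" for prompt in batch]


out = INFER_MODES["batch"](EchoLM(), ["hello", "world"], limit_prompt=3)
# -> ["echo: hel", "echo: wor"], assuming limit_prompts trims each prompt
#    to `limit` characters (mirroring the removed "default" mode).
print(out)
```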
@@ -20,44 +20,85 @@ INFER_MODES = {
  CWD = os.getcwd()


- def _update_batch_content(c, batch, schema, infer_func):
-     assert (isinstance(batch, list))
-     assert (isinstance(c, str))
+ def _iter_entry_content(entry, entry_info=None, **kwargs):
+
+     if isinstance(entry, str):
+         kwargs.get("callback_str_func", lambda *_: None)(entry, entry_info)
+         yield entry
+     elif isinstance(entry, collections.abc.Iterable):
+         h = kwargs.get("callback_stream_func", lambda *_: None)
+         h(None, entry_info | {"action": "start"})
+         for chunk in map(lambda item: str(item), entry):
+             yield chunk
+             h(chunk, entry_info)
+         h(None, entry_info | {"action": "end"})
+     else:
+         raise Exception(f"Non supported type `{type(entry)}` for handling output from batch")
+
+
+ def _iter_batch_prompts(c, batch_content_it, **kwargs):
+     for ind_in_batch, entry in enumerate(batch_content_it):
+         content = DataService.get_prompt_text(
+             prompt=entry[c]["prompt"],
+             data_dict=entry,
+             handle_missed_func=kwargs["handle_missed_value_func"])
+         yield ind_in_batch, content
+

-     if c in schema.p2r:
-         for batch_item in batch:
-             batch_item[c] = DataService.get_prompt_text(prompt=batch_item[c]["prompt"], data_dict=batch_item)
-     if c in schema.r2p:
-         p_column = schema.r2p[c]
-         # This instruction takes a lot of time in a non-batching mode.
-         BatchService.handle_param_as_batch(batch=batch,
-                                            src_param=p_column,
-                                            tgt_param=c,
-                                            handle_func=lambda b: infer_func(b))
+ def _iter_batch_responses(p_column, c, batch_content_it, **kwargs):
+     p_batch = [item[p_column] for item in batch_content_it]
+     # TODO. This part could be async.
+     # TODO. ind_in_batch might be a part of the async return.
+     for ind_in_batch, entry in enumerate(kwargs["handle_batch_func"](p_batch)):
+         yield ind_in_batch, _iter_entry_content(entry=entry, entry_info={"ind": ind_in_batch, "param": c}, **kwargs)


- def _infer_batch(batch, schema, infer_func, cols=None):
+ def _infer_batch(batch, schema, return_mode, cols=None, **kwargs):
      assert (isinstance(batch, list))
-     assert (callable(infer_func))

      if len(batch) == 0:
          return batch

      if cols is None:
          first_item = batch[0]
-         cols = first_item.keys() if cols is None else cols
+         cols = list(first_item.keys()) if cols is None else cols

      for c in cols:
-         _update_batch_content(c=c, batch=batch, schema=schema, infer_func=infer_func)

-     return batch
+         # Handling prompt column.
+         if c in schema.p2r:
+             content_it = _iter_batch_prompts(c=c, batch_content_it=iter(batch), **kwargs)
+             for ind_in_batch, prompt in content_it:
+                 batch[ind_in_batch][c] = prompt
+
+         # Handling column for inference.
+         if c in schema.r2p:
+             content_it = _iter_batch_responses(c=c, p_column=schema.r2p[c], batch_content_it=iter(batch), **kwargs)
+             for ind_in_batch, chunk_it in content_it:
+
+                 chunks = []
+                 for chunk in chunk_it:
+                     chunks.append(chunk)
+
+                     if return_mode == "chunk":
+                         yield [ind_in_batch, c, chunk]
+
+                 batch[ind_in_batch][c] = "".join(chunks)
+
+     if return_mode == "record":
+         for record in batch:
+             yield record
+
+     if return_mode == "batch":
+         yield batch


- def iter_content(input_dicts_it, llm, schema, batch_size=1, return_batch=True, limit_prompt=None):
+ def iter_content(input_dicts_it, llm, schema, batch_size=1, limit_prompt=None, return_mode="batch", **kwargs):
      """ This method represent Python API aimed at application of `llm` towards
      iterator of input_dicts via cache_target that refers to the SQLite using
      the given `schema`
      """
+     assert (return_mode in ["batch", "chunk"])
      assert (isinstance(llm, BaseLM))

      # Quick initialization of the schema.
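The new `_iter_entry_content` helper normalizes both plain-string and streamed responses, reporting progress through two optional callbacks supplied via `**kwargs`. A runnable sketch that drives the (private) helper directly, purely for illustration:

```python
from bulk_chain.api import _iter_entry_content


def on_stream(chunk, info):
    # `chunk` is None at the {"action": "start"} / {"action": "end"} markers
    # that _iter_entry_content emits around a streamed entry.
    if chunk is None:
        print(f"\n[{info['param']}#{info['ind']}] {info['action']}")
    else:
        print(chunk, end="")


# Any non-str iterable is treated as a stream and fires callback_stream_func;
# a plain str entry would fire callback_str_func(entry, info) instead.
stream = iter(["The ", "sentiment ", "is ", "positive."])
text = "".join(_iter_entry_content(entry=stream,
                                   entry_info={"ind": 0, "param": "response"},
                                   callback_stream_func=on_stream))
```

In normal use these callbacks are not called directly: they are simply forwarded from `iter_content(**kwargs)` down to the batch internals.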
@@ -67,21 +108,24 @@ def iter_content(input_dicts_it, llm, schema, batch_size=1, return_batch=True, limit_prompt=None):
          schema = SchemaService(json_data=schema)

      prompts_it = map(
-         lambda data: DictionaryService.custom_update(src_dict=data, other_dict=schema.cot_args),
+         lambda data: DictionaryService.custom_update(src_dict=dict(data), other_dict=schema.cot_args),
          input_dicts_it
      )

      content_it = (_infer_batch(batch=batch,
-                                infer_func=lambda batch: INFER_MODES["batch"](llm, batch, limit_prompt),
-                                schema=schema)
+                                handle_batch_func=lambda batch: INFER_MODES["batch"](llm, batch, limit_prompt),
+                                return_mode=return_mode,
+                                schema=schema,
+                                **kwargs)
                    for batch in BatchIterator(prompts_it, batch_size=batch_size))

-     yield from content_it if return_batch else chain.from_iterable(content_it)
+     yield from chain.from_iterable(content_it)


  def init_llm(adapter, **model_kwargs):
      """ This method perform dynamic initialization of LLM from third-party resource.
      """
+     assert (isinstance(adapter, str))

      # List of the Supported models and their API wrappers.
      models_preset = {
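Putting the hunks above together: `return_mode="chunk"` makes `iter_content` yield `[ind_in_batch, column, chunk]` triples as responses stream in, while `"batch"` yields one fully filled batch at a time. A hedged usage sketch; `llm` stands for any initialized `BaseLM` adapter, and the diff reads as though `handle_missed_value_func` must be supplied once a prompt column is rendered (a no-op fallback is shown):

```python
from bulk_chain.api import iter_content

input_it = iter([{"text": "It was great!"}, {"text": "Awful."}])

for ind_in_batch, col, chunk in iter_content(
        input_dicts_it=input_it,
        llm=llm,                          # any BaseLM adapter
        schema="schema.json",             # path, dict, or SchemaService
        batch_size=2,
        return_mode="chunk",
        handle_missed_value_func=lambda col_name: ""):
    print(f"[{ind_in_batch}][{col}] {chunk}")
```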
bulk_chain/core/llm_base.py CHANGED
@@ -1,8 +1,6 @@
  import logging
  import time

- from bulk_chain.core.utils import format_model_name
-

  class BaseLM(object):

@@ -49,4 +47,4 @@ class BaseLM(object):
          raise NotImplemented()

      def name(self):
-         return format_model_name(self.__name)
+         return self.__name.replace("/", "_")
bulk_chain/core/service_batch.py CHANGED
@@ -1,31 +1,13 @@
- class BatchService(object):
-
-     @staticmethod
-     def handle_param_as_batch(batch, src_param, tgt_param, handle_func):
-         assert (isinstance(batch, list))
-         assert (isinstance(src_param, str))
-         assert (callable(handle_func))
-
-         _batch = [item[src_param] for item in batch]
-
-         # Do handling for the batch.
-         _handled_batch = handle_func(_batch)
-         assert (isinstance(_handled_batch, list))
-
-         # Apply changes.
-         for i, item in enumerate(batch):
-             item[tgt_param] = _handled_batch[i]
-
-
  class BatchIterator:

-     def __init__(self, data_iter, batch_size, end_value=None):
+     def __init__(self, data_iter, batch_size, end_value=None, filter_func=None):
          assert(isinstance(batch_size, int) and batch_size > 0)
          assert(callable(end_value) or end_value is None)
          self.__data_iter = data_iter
          self.__index = 0
          self.__batch_size = batch_size
          self.__end_value = end_value
+         self.__filter_func = (lambda _: True) if filter_func is None else filter_func

      def __iter__(self):
          return self
@@ -37,7 +19,8 @@ class BatchIterator:
              data = next(self.__data_iter)
          except StopIteration:
              break
-         buffer.append(data)
+         if self.__filter_func(data):
+             buffer.append(data)
          if len(buffer) == self.__batch_size:
              break
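The new `filter_func` drops items before they enter the buffer, so each batch is filled only with items that pass the predicate. A small sketch, assuming `BatchIterator` stops once the source iterator is exhausted:

```python
from bulk_chain.core.service_batch import BatchIterator

rows = iter([{"text": "keep me"}, {"text": ""}, {"text": "keep me too"}])
batches = BatchIterator(rows, batch_size=2,
                        filter_func=lambda row: len(row["text"]) > 0)

for batch in batches:
    print(batch)  # [{'text': 'keep me'}, {'text': 'keep me too'}]
```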
bulk_chain/core/service_data.py CHANGED
@@ -4,8 +4,8 @@ from bulk_chain.core.utils import iter_params
  class DataService(object):

      @staticmethod
-     def compose_prompt_text(prompt, data_dict, field_names):
-         assert(isinstance(data_dict, dict))
+     def __compose_prompt_text(prompt, data_dict, field_names):
+         assert (isinstance(data_dict, dict))
          fmt_d = {col_name: data_dict[col_name] for col_name in field_names}

          # Guarantee that items has correct type.
@@ -16,10 +16,14 @@ class DataService(object):
          return prompt.format(**fmt_d)

      @staticmethod
-     def get_prompt_text(prompt, data_dict, parse_fields_func=iter_params):
+     def get_prompt_text(prompt, data_dict, parse_fields_func=iter_params, handle_missed_func=None):
          field_names = list(parse_fields_func(prompt))
-         return DataService.compose_prompt_text(
-             prompt=prompt, data_dict=data_dict, field_names=field_names)
+
+         for col_name in field_names:
+             if col_name not in data_dict:
+                 data_dict[col_name] = handle_missed_func(col_name)
+
+         return DataService.__compose_prompt_text(prompt=prompt, data_dict=data_dict, field_names=field_names)

      @staticmethod
      def limit_prompts(prompts_list, limit=None):
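`get_prompt_text` now patches missing template fields through `handle_missed_func` before formatting; note that it mutates `data_dict` in place, and with the hook left at its `None` default a missing field raises a `TypeError` when the hook is invoked. A sketch:

```python
from bulk_chain.core.service_data import DataService

row = {"text": "The coffee was cold."}
prompt = DataService.get_prompt_text(
    prompt="Sentiment of `{text}` toward `{target}`:",
    data_dict=row,
    handle_missed_func=lambda col: f"<missing:{col}>")

print(prompt)  # the patched "target" key now also lives in `row`
```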
bulk_chain/core/utils.py CHANGED
@@ -2,6 +2,7 @@ import importlib
  import logging
  import sys
  from collections import Counter
+ from os.path import dirname, join, basename

  logger = logging.getLogger(__name__)
  logging.basicConfig(level=logging.INFO)
@@ -47,28 +48,6 @@ def iter_params(text):
          beg = pe+1


- def format_model_name(name):
-     return name.replace("/", "_")
-
-
- def parse_filepath(filepath, default_filepath=None, default_ext=None):
-     """ This is an auxiliary function for handling sources and targets from cmd string.
-     """
-     if filepath is None:
-         return default_filepath, default_ext, None
-     info = filepath.split(":")
-     filepath = info[0]
-     meta = info[1] if len(info) > 1 else None
-     ext = filepath.split('.')[-1] if default_ext is None else default_ext
-     return filepath, ext, meta
-
-
- def handle_table_name(name):
-     return name.\
-         replace('-', '_').\
-         replace('.', "_")
-
-
  def auto_import(name, is_class=False):
      """ Import from the external python packages.
      """
@@ -82,13 +61,24 @@ def auto_import(name, is_class=False):


  def dynamic_init(class_dir, class_filepath, class_name=None):
-     sys.path.append(class_dir)
+
+     # Registering path.
+     target = join(class_dir, dirname(class_filepath))
+     logger.info(f"Adding sys path for `{target}`")
+     sys.path.insert(1, target)
      class_path_list = class_filepath.split('/')
-     class_path_list[-1] = '.'.join(class_path_list[-1].split('.')[:-1])
+
+     # Composing proper class name.
+     class_filename = basename(class_path_list[-1])
+     if class_filename.endswith(".py"):
+         class_filename = class_filename[:-len(".py")]
+
+     # Loading library.
      class_name = class_path_list[-1].title() if class_name is None else class_name
-     class_path = ".".join(class_path_list + [class_name])
+     class_path = ".".join([class_filename, class_name])
      logger.info(f"Dynamic loading for the file and class `{class_path}`")
      cls = auto_import(class_path, is_class=False)
+
      return cls

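`dynamic_init` now registers `join(class_dir, dirname(class_filepath))` on `sys.path` and imports `<module>.<ClassName>`, deriving the module name from the file's basename. A hedged sketch, assuming a provider file `providers/replicate_104.py` defining `class Replicate` (the constructor kwargs depend entirely on that provider):

```python
from bulk_chain.core.utils import dynamic_init

# sys.path gains "./providers"; auto_import then resolves
# the dotted path "replicate_104.Replicate".
adapter_cls = dynamic_init(class_dir=".",
                           class_filepath="providers/replicate_104.py",
                           class_name="Replicate")

llm = adapter_cls(api_token="<REPLICATE-API-TOKEN>",
                  model_name="meta/meta-llama-3-70b-instruct")
```

Note that when `class_name` is omitted, the name is still derived via `class_path_list[-1].title()`, i.e. from the component that includes the `.py` suffix, so passing it explicitly is the safe route.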
bulk_chain-1.0.0.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: bulk_chain
- Version: 0.25.2
+ Version: 1.0.0
  Summary: A lightweight, no-strings-attached Chain-of-Thought framework for your LLM, ensuring reliable results for bulk input requests.
  Home-page: https://github.com/nicolay-r/bulk-chain
  Author: Nicolay Rusnachenko
@@ -16,9 +16,8 @@ Requires-Python: >=3.6
  Description-Content-Type: text/markdown
  License-File: LICENSE
  Requires-Dist: tqdm
- Requires-Dist: source-iter ==0.24.3

- # bulk-chain 0.25.2
+ # bulk-chain 1.0.0
  ![](https://img.shields.io/badge/Python-3.9-brightgreen.svg)
  [![](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/nicolay-r/bulk-chain/blob/master/bulk_chain_tutorial.ipynb)
  [![twitter](https://img.shields.io/twitter/url/https/shields.io.svg?style=social)](https://x.com/nicolayr_/status/1847969224636961033)
@@ -31,7 +30,7 @@ Requires-Dist: source-iter ==0.24.3
  <p align="center">
      <a href="https://github.com/nicolay-r/nlp-thirdgate?tab=readme-ov-file#llm"><b>Third-party providers hosting</b>↗️</a>
      <br>
-     <a href="https://github.com/nicolay-r/bulk-chain/blob/master/README.md#demo-mode">👉<b>demo</b>👈</a>
+     <a href="https://github.com/nicolay-r/bulk-chain-shell">👉<b>demo</b>👈</a>
  </p>

  A no-strings-attached **framework** for your LLM that allows applying Chain-of-Thought-alike [prompt `schema`](#chain-of-thought-schema) towards a massive textual collections using custom **[third-party providers ↗️](https://github.com/nicolay-r/nlp-thirdgate?tab=readme-ov-file#llm)**.
@@ -39,11 +38,7 @@ A no-strings-attached **framework** for your LLM that allows applying Chain-of-
  ### Main Features
  * ✅ **No-strings**: you're free to LLM dependencies and flexible `venv` customization.
  * ✅ **Support schemas descriptions** for Chain-of-Thought concept.
- * ✅ **Provides iterator over infinite amount of input contexts** served in `CSV`/`JSONL`.
-
- ### Extra Features
- * ✅ **Progress caching [for remote LLMs]**: withstanding exception during LLM calls by using `sqlite3` engine for caching LLM answers;
-
+ * ✅ **Provides iterator over infinite amount of input contexts**

  # Installation

@@ -88,51 +83,8 @@ Preliminary steps:
  1. Define your [schema](#chain-of-thought-schema) ([Example for Sentiment Analysis](/ext/schema/thor_cot_schema.json)))
  2. Wrap or pick **LLM model** from the [<b>Third-party providers hosting</b>↗️](https://github.com/nicolay-r/nlp-thirdgate?tab=readme-ov-file#llm).

- ## Shell
-
- ### Demo Mode
-
- **demo mode** to interact with LLM via command line with LLM output streaming support.
- The video below illustrates an example of application for sentiment analysis on author opinion extraction towards mentioned object in text.
-
- Quck start with launching demo:
- 1. ⬇️ Download [replicate](https://replicate.com/) provider for `bulk-chain`:
- 2. 📜 Setup your reasoning `thor_cot_schema.json` according to the [following example ↗️](test/schema/thor_cot_schema.json)
- 3. 🚀 Launch `demo.py` as follows:
- ```bash
- python3 -m bulk_chain.demo \
-     --schema "test/schema/thor_cot_schema.json" \
-     --adapter "dynamic:replicate_104.py:Replicate" \
-     %%m \
-     --model_name "meta/meta-llama-3-70b-instruct" \
-     --api_token "<REPLICATE-API-TOKEN>" \
-     --stream
- ```
-
- 📺 This video showcase application of the [↗️ Sentiment Analysis Schema](https://github.com/nicolay-r/bulk-chain/blob/master/test/schema/thor_cot_schema.json) towards [LLaMA-3-70B-Instruct](https://replicate.com/meta/meta-llama-3-70b-instruct) hosted by Replicate for reasoning over submitted texts
- ![sa-bulk-chain-cot-final](https://github.com/user-attachments/assets/0cc8fdcb-6ddb-44a3-8f05-d76250ae6423)


- ### Inference Mode
-
- > **NOTE:** You have to install `source-iter` and `tqdm` packages that actual [dependencies](dependencies.txt) of this project
-
- 1. ⬇️ Download [replicate](https://replicate.com/) provider for `bulk-chain`:
- ```bash
- wget https://raw.githubusercontent.com/nicolay-r/nlp-thirdgate/refs/heads/master/llm/replicate_104.py
- ```
- 2. 📜 Setup your reasoning `schema.json` according to the [following example ↗️](test/schema/default.json)
- 3. 🚀 Launch inference using `DeepSeek-R1`:
- ```bash
- python3 -m bulk_chain.infer \
-     --src "<PATH-TO-YOUR-CSV-or-JSONL>" \
-     --schema "test/schema/default.json" \
-     --adapter "replicate_104.py:Replicate" \
-     %%m \
-     --model_name "deepseek-ai/deepseek-r1" \
-     --api_token "<REPLICATE-API-TOKEN>"
- ```
-
  ## API

  Please take a look at the [**related Wiki page**](https://github.com/nicolay-r/bulk-chain/wiki)
bulk_chain-1.0.0.dist-info/RECORD ADDED
@@ -0,0 +1,15 @@
+ bulk_chain/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ bulk_chain/api.py,sha256=d_c10Je8wUSnCdQjyWCHVx4FGW6M2_pBMMqKsI_YJaY,5119
+ bulk_chain/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ bulk_chain/core/llm_base.py,sha256=DZ9l4HpCs9uKTZp68miw_XCqmRAJBqQPuYSK889CeUk,1785
+ bulk_chain/core/service_batch.py,sha256=LMxrZeQXV_AJAoCaMCHVx8TvjcmCaKUQhNE8K4D8pCo,1031
+ bulk_chain/core/service_data.py,sha256=OWWHHnr_plwxYTxLuvMrhEc1PbSx-XC3rbFzV0hy3vk,1107
+ bulk_chain/core/service_dict.py,sha256=lAghLU-3V3xYGv5BTA327Qcw8UJYmgQRMFdggzlrUgo,383
+ bulk_chain/core/service_json.py,sha256=6o1xM_8c9QEjH9Q3qEmJylU9nahfRXhUd5sFF2dGJwo,182
+ bulk_chain/core/service_schema.py,sha256=KIP4n0Tz2h1i7SIMGhgAhoiCgUFXOT1rzMt38yACS2U,1154
+ bulk_chain/core/utils.py,sha256=Dx9Gy-jPpk-w_8WUekN0Ij4RBIWVAPg74vA3N0JgGqc,2471
+ bulk_chain-1.0.0.dist-info/LICENSE,sha256=VF9SjNpwwSSFEY_eP_8A1ocDCrbwfjI1pZexXdCkOwo,1076
+ bulk_chain-1.0.0.dist-info/METADATA,sha256=TR86CmhcHJ3Sep8TlHZ0Ede_PnH8G5iMILUvVvSskJY,3810
+ bulk_chain-1.0.0.dist-info/WHEEL,sha256=pL8R0wFFS65tNSRnaOVrsw9EOkOqxLrlUPenUYnJKNo,91
+ bulk_chain-1.0.0.dist-info/top_level.txt,sha256=Hxq_wyH-GDXKBaA63UfBIiMJO2eCHJG5EOrXDphpeB4,11
+ bulk_chain-1.0.0.dist-info/RECORD,,
bulk_chain/core/service_args.py DELETED
@@ -1,72 +0,0 @@
- class CmdArgsService:
-
-     @staticmethod
-     def autocast(v):
-         for t in [int, float, str]:
-             try:
-                 return t(v)
-             except:
-                 pass
-
-     @staticmethod
-     def iter_arguments(lst):
-
-         def __release():
-
-             # We use the True value by default to treat the related parameter as flag.
-             if len(buf) == 0:
-                 buf.append(True)
-
-             return key, buf if len(buf) > 1 else buf[0]
-
-         key = None
-         buf = []
-         for a in lst:
-             if a.startswith('--'):
-                 # release
-                 if key is not None:
-                     yield __release()
-                 # set new key and empty buf
-                 key = a[2:]
-                 buf = []
-             else:
-                 # append argument into buffer.
-                 buf.append(a)
-
-         # Sharing the remaining params.
-         if key is not None:
-             yield __release()
-
-     @staticmethod
-     def __find_suffix_ind(lst, idx_from, end_prefix):
-         for i in range(idx_from, len(lst)):
-             if lst[i].startswith(end_prefix):
-                 return i
-         return len(lst)
-
-     @staticmethod
-     def extract_native_args(lst, end_prefix):
-         return lst[:CmdArgsService.__find_suffix_ind(lst, idx_from=0, end_prefix=end_prefix)]
-
-     @staticmethod
-     def find_grouped_args(lst, starts_with, end_prefix):
-         """Slices a list in two, cutting on index matching "sep"
-         """
-
-         # Checking the presence of starts_with.
-         # We have to return empty content in the case of absence starts_with in the lst.
-         if starts_with not in lst:
-             return []
-
-         # Assigning start index.
-         idx_from = lst.index(starts_with) + 1
-
-         # Assigning end index.
-         idx_to = CmdArgsService.__find_suffix_ind(lst, idx_from=idx_from, end_prefix=end_prefix)
-
-         return lst[idx_from:idx_to]
-
-     @staticmethod
-     def args_to_dict(args):
-         return {k: CmdArgsService.autocast(v) if not isinstance(v, list) else v
-                 for k, v in CmdArgsService.iter_arguments(args)} if args is not None else {}
bulk_chain/core/service_llm.py DELETED
@@ -1,94 +0,0 @@
- from bulk_chain.core.llm_base import BaseLM
- from bulk_chain.core.service_data import DataService
- from bulk_chain.core.utils import iter_params
- from bulk_chain.core.utils_logger import StreamedLogger
-
-
- def pad_str(text, pad):
-     return text.rjust(len(text) + pad, ' ')
-
-
- def nice_output(text, remove_new_line=False):
-     short_text = text.replace("\n", "") if remove_new_line else text
-     return short_text
-
-
- def chat_with_lm(lm, preset_dict=None, chain=None, model_name=None, pad=0):
-     assert (isinstance(lm, BaseLM))
-     assert (isinstance(chain, list))
-     assert (isinstance(model_name, str) or model_name is None)
-
-     preset_dict = {} if preset_dict is None else preset_dict
-
-     streamed_logger = StreamedLogger(__name__)
-
-     do_exit = False
-     model_name = model_name if model_name is not None else "agent"
-
-     while not do_exit:
-
-         streamed_logger.info("----------------")
-         streamed_logger.info("\n")
-
-         # Launching the CoT engine loop.
-         data_dict = {} | preset_dict
-         for chain_ind, prompt_args in enumerate(chain):
-
-             # Processing the prompt.
-             prompt = prompt_args["prompt"]
-
-             # Filling necessary parameters.
-             user_informed = False
-             field_names = list(iter_params(prompt))
-             for ind, f_name in enumerate(field_names):
-
-                 if f_name in data_dict:
-                     continue
-
-                 user_input = input(f"Enter your prompt for `{f_name}` ({ind+1}/{len(field_names)}) "
-                                    f"(or 'exit' to quit): ")
-                 user_informed = True
-
-                 if user_input.lower() == 'exit':
-                     do_exit = True
-                     break
-
-                 data_dict[f_name] = user_input
-
-             if do_exit:
-                 break
-
-             # In the case of the initial interaction with the chain.
-             # we make sure that aware user for starting interaction.
-             if chain_ind == 0 and not user_informed:
-                 user_input = input(f"Enter to continue (or 'exit' to quit) ...")
-                 if user_input.lower() == 'exit':
-                     do_exit = True
-
-             # Finally asking LLM.
-             DataService.compose_prompt_text(prompt=prompt, data_dict=data_dict, field_names=field_names)
-             actual_prompt = DataService.get_prompt_text(prompt=prompt, data_dict=data_dict)
-
-             # Returning meta information, passed to LLM.
-             streamed_logger.info(pad_str(f"{model_name} (ask [{chain_ind+1}/{len(chain)}]) ->", pad=pad))
-             streamed_logger.info("\n")
-             streamed_logger.info(nice_output(actual_prompt, remove_new_line=True))
-             streamed_logger.info("\n\n")
-
-             # Response.
-             response = lm.ask_core(batch=[actual_prompt])[0]
-             streamed_logger.info(pad_str(f"{model_name} (resp [{chain_ind+1}/{len(chain)}])->", pad=pad))
-             streamed_logger.info("\n")
-             if isinstance(response, str):
-                 streamed_logger.info(nice_output(response, remove_new_line=False))
-                 buffer = [response]
-             else:
-                 buffer = []
-                 for chunk in response:
-                     streamed_logger.info(chunk)
-                     buffer.append(str(chunk))
-
-             streamed_logger.info("\n\n")
-
-             # Collecting the answer for the next turn.
-             data_dict[prompt_args["out"]] = "".join(buffer)
bulk_chain/core/utils_logger.py DELETED
@@ -1,41 +0,0 @@
- import logging
-
-
- def StreamedLogger(name: str) -> logging.Logger:
-     """ https://medium.com/@r.das699/optimizing-logging-practices-for-streaming-data-in-python-521798e1ed82
-     """
-     root_handlers = logging.getLogger().handlers
-     current_logger = logging.getLogger(name)
-     if not root_handlers:
-         new_handler = logging.StreamHandler()
-         new_handler.terminator = ""
-         new_handler.setFormatter(logging.Formatter("%(message)s"))
-         current_logger.addHandler(new_handler)
-         current_logger.propagate = False
-         current_logger.setLevel(logging.INFO)
-         return current_logger
-
-     for handler in current_logger.handlers[:]:
-         current_logger.removeHandler(handler)
-
-     for handler_r in root_handlers:
-         if type(handler_r) is logging.StreamHandler:
-             new_handler = logging.StreamHandler()
-             new_handler.terminator = ""
-             new_handler.setFormatter(logging.Formatter("%(message)s"))
-             current_logger.addHandler(new_handler)
-         elif type(handler_r) is logging.FileHandler:
-             new_handler = logging.FileHandler(
-                 handler_r.baseFilename,
-                 handler_r.mode,
-                 handler_r.encoding,
-                 handler_r.delay,
-                 handler_r.errors,
-             )
-             new_handler.terminator = ""  # This will stop the printing in new line
-             new_handler.setFormatter(logging.Formatter("%(message)s"))
-             current_logger.addHandler(new_handler)
-         else:
-             continue
-     current_logger.propagate = False  # Don't propagate to root logger
-     return current_logger
bulk_chain/demo.py DELETED
@@ -1,84 +0,0 @@
- import json
-
- import argparse
- import logging
- import sys
-
- from source_iter.service_jsonl import JsonlService
-
- from bulk_chain.api import init_llm
- from bulk_chain.core.service_args import CmdArgsService
- from bulk_chain.core.service_json import JsonService
- from bulk_chain.core.service_llm import chat_with_lm
- from bulk_chain.core.service_schema import SchemaService
- from bulk_chain.core.utils import parse_filepath
-
- logger = logging.getLogger(__name__)
- logging.basicConfig(level=logging.INFO)
-
-
- def iter_from_json(filepath):
-     with open(filepath, "r") as f:
-         content = json.load(f)
-         for key, value in content.items():
-             yield key, value
-
-
- def iter_from_text_file(filepath):
-     with open(filepath, "r") as f:
-         yield filepath.split('.')[0], f.read()
-
-
- if __name__ == '__main__':
-
-     parser = argparse.ArgumentParser(description="LLM demo usage based on CoT schema")
-     parser.add_argument('--adapter', dest='adapter', type=str, default=None)
-     parser.add_argument('--attempts', dest='attempts', type=int, default=None)
-     parser.add_argument('--src', dest='src', type=str, nargs="*", default=None)
-     parser.add_argument('--schema', dest='schema', type=str, default=None,
-                         help="Path to the JSON file that describes schema")
-     parser.add_argument('--limit-prompt', dest="limit_prompt", type=int, default=None,
-                         help="Optional trimming prompt by the specified amount of characters.")
-
-     # Extract native arguments.
-     native_args = CmdArgsService.extract_native_args(sys.argv, end_prefix="%%")
-     args = parser.parse_args(args=native_args[1:])
-
-     # Extract model-related arguments and Initialize Large Language Model.
-     model_args = CmdArgsService.find_grouped_args(lst=sys.argv, starts_with="%%m", end_prefix="%%")
-     model_args_dict = CmdArgsService.args_to_dict(model_args) | {"attempts": args.attempts}
-     llm, llm_model_name = init_llm(adapter=args.adapter, **model_args_dict)
-
-     # Setup schema.
-     schema = SchemaService(json_data=JsonService.read(args.schema))
-     schema_name = schema.src.get("name", None)
-     if schema is not None:
-         logger.info(f"Using schema: {schema_name}")
-
-     output_providers = {
-         "jsonl": lambda filepath, data_it, header:
-             JsonlService.write(target=filepath,
-                                data_it=map(lambda item: {key: item[i] for i, key in enumerate(header)}, data_it))
-     }
-
-     input_file_handlers = {
-         "json": lambda filepath: iter_from_json(filepath),
-         "txt": lambda filepath: iter_from_text_file(filepath)
-     }
-
-     # Input extension type defines the provider.
-     if args.src is None:
-         args.src = []
-     if isinstance(args.src, str):
-         args.src = [args.src]
-     sources = [parse_filepath(s) for s in args.src]
-
-     preset_dict = {}
-     for fp, ext, _ in sources:
-         for key, value in input_file_handlers[ext](fp):
-             if key in preset_dict:
-                 raise Exception(f"While at handling {fp}: Key {key} is already registered!")
-             preset_dict[key] = value
-
-     # Launch Demo.
-     chat_with_lm(llm, preset_dict=preset_dict, chain=schema.chain, model_name=llm_model_name)
bulk_chain/infer.py DELETED
@@ -1,161 +0,0 @@
- from os.path import join, basename
-
- import argparse
- import logging
- import sys
-
- from source_iter.service_csv import CsvService
- from source_iter.service_jsonl import JsonlService
- from source_iter.service_sqlite import SQLite3Service
- from tqdm import tqdm
-
- from bulk_chain.api import INFER_MODES, _infer_batch, CWD, init_llm
- from bulk_chain.core.llm_base import BaseLM
- from bulk_chain.core.service_args import CmdArgsService
- from bulk_chain.core.service_dict import DictionaryService
- from bulk_chain.core.service_json import JsonService
- from bulk_chain.core.service_schema import SchemaService
- from bulk_chain.core.utils import handle_table_name, optional_limit_iter, parse_filepath
-
- logger = logging.getLogger(__name__)
- logging.basicConfig(level=logging.INFO)
-
- WRITER_PROVIDERS = {
-     "sqlite": lambda filepath, table_name, data_it, infer_data_func, **kwargs: SQLite3Service.write(
-         data_it=data_it, target=filepath, table_name=table_name, data2col_func=infer_data_func,
-         skip_existed=True, **kwargs)
- }
-
- READER_PROVIDERS = {
-     "sqlite": lambda filepath, table_name: SQLite3Service.read(filepath, table=table_name)
- }
-
-
- def iter_content_cached(input_dicts_it, llm, schema, cache_target, limit_prompt=None, **cache_kwargs):
-     assert (isinstance(llm, BaseLM))
-     assert (isinstance(cache_target, str))
-
-     # Quick initialization of the schema.
-     if isinstance(schema, str):
-         schema = JsonService.read(schema)
-     if isinstance(schema, dict):
-         schema = SchemaService(json_data=schema)
-
-     # Iterator of the queries.
-     prompts_it = map(
-         lambda data: DictionaryService.custom_update(src_dict=data, other_dict=schema.cot_args),
-         input_dicts_it
-     )
-
-     # Parse target.
-     cache_filepath, _, cache_table = parse_filepath(filepath=cache_target)
-
-     # Perform caching first.
-     WRITER_PROVIDERS["sqlite"](
-         filepath=cache_filepath, table_name=cache_table,
-         data_it=tqdm(prompts_it, desc="Iter content"),
-         infer_data_func=lambda c, prompt: _infer_batch(
-             batch=[prompt], cols=[c],
-             infer_func=lambda batch: INFER_MODES["default"](llm, batch, limit_prompt),
-             schema=schema)[0][c],
-         **cache_kwargs)
-
-     # Then retrieve data.
-     return READER_PROVIDERS["sqlite"](filepath=cache_filepath, table_name=cache_table)
-
-
- if __name__ == '__main__':
-
-     parser = argparse.ArgumentParser(description="Infer Instruct LLM inference based on CoT schema")
-     parser.add_argument('--adapter', dest='adapter', type=str, default=None)
-     parser.add_argument('--id-col', dest='id_col', type=str, default="uid")
-     parser.add_argument('--src', dest='src', type=str, nargs="?", default=None)
-     parser.add_argument('--schema', dest='schema', type=str, default=None,
-                         help="Path to the JSON file that describes schema")
-     parser.add_argument('--to', dest='to', type=str, default=None, choices=["csv", "sqlite"])
-     parser.add_argument('--output', dest='output', type=str, default=None)
-     parser.add_argument('--limit', dest='limit', type=int, default=None,
-                         help="Limit amount of source texts for prompting.")
-     parser.add_argument('--limit-prompt', dest="limit_prompt", type=int, default=None,
-                         help="Optional trimming prompt by the specified amount of characters.")
-
-     # Extract native arguments.
-     native_args = CmdArgsService.extract_native_args(sys.argv, end_prefix="%%")
-     args = parser.parse_args(args=native_args[1:])
-
-     # Extract csv-related arguments.
-     csv_args = CmdArgsService.find_grouped_args(lst=sys.argv, starts_with="%%csv", end_prefix="%%")
-     csv_args_dict = CmdArgsService.args_to_dict(csv_args)
-
-     # Extract model-related arguments and Initialize Large Language Model.
-     model_args = CmdArgsService.find_grouped_args(lst=sys.argv, starts_with="%%m", end_prefix="%%")
-     model_args_dict = CmdArgsService.args_to_dict(model_args) | {"attempts": args.attempts}
-     llm, llm_model_name = init_llm(adapter=args.adapter, **model_args_dict)
-
-     # Setup schema.
-     schema = SchemaService(json_data=JsonService.read(args.schema))
-     schema_name = schema.src.get("name", None)
-     if schema is not None:
-         logger.info(f"Using schema: {schema_name}")
-
-     input_providers = {
-         "csv": lambda filepath: CsvService.read(src=filepath, row_id_key=args.id_col,
-                                                 as_dict=True, skip_header=True,
-                                                 delimiter=csv_args_dict.get("delimiter", ","),
-                                                 escapechar=csv_args_dict.get("escapechar", None)),
-         "tsv": lambda filepath: CsvService.read(src=filepath, row_id_key=args.id_col,
-                                                 as_dict=True, skip_header=True,
-                                                 delimiter=csv_args_dict.get("delimiter", "\t"),
-                                                 escapechar=csv_args_dict.get("escapechar", None)),
-         "jsonl": lambda filepath: JsonlService.read(src=filepath, row_id_key=args.id_col)
-     }
-
-     output_providers = {
-         "csv": lambda filepath, data_it, header: CsvService.write(target=filepath,
-                                                                   data_it=data_it, header=header,
-                                                                   delimiter=csv_args_dict.get("delimiter", ","),
-                                                                   escapechar=csv_args_dict.get("escapechar", None),
-                                                                   it_type=None),
-         "tsv": lambda filepath, data_it, header: CsvService.write(target=filepath,
-                                                                   data_it=data_it, header=header,
-                                                                   delimiter=csv_args_dict.get("delimiter", "\t"),
-                                                                   escapechar=csv_args_dict.get("escapechar", None),
-                                                                   it_type=None),
-         "jsonl": lambda filepath, data_it, header:
-             JsonlService.write(target=filepath,
-                                data_it=map(lambda item: {key: item[i] for i, key in enumerate(header)}, data_it))
-     }
-
-     # Setup output.
-     args.output = args.output.format(model=llm.name()) if args.output is not None else args.output
-     tgt_filepath, tgt_ext, tgt_meta = parse_filepath(args.output, default_ext=args.to)
-
-     # We do not support multiple files for other modes.
-     src_filepath, src_ext, src_meta = parse_filepath(args.src)
-
-     def default_output_file_template(ext):
-         # This is a default template for output files to be generated.
-         return "".join(["_".join([join(CWD, basename(src_filepath)), llm.name(), schema_name]), ext])
-
-     # Setup cache target as well as the related table.
-     cache_filepath = default_output_file_template(".sqlite") if tgt_filepath is None else tgt_filepath
-     cache_table = handle_table_name(tgt_meta if tgt_meta is not None else "contents")
-
-     # This is a content that we extracted via input provider.
-     it_data = input_providers[src_ext](src_filepath)
-
-     data_it = iter_content_cached(input_dicts_it=optional_limit_iter(it_data=it_data, limit=args.limit),
-                                   limit_prompt=args.limit_prompt,
-                                   schema=schema,
-                                   llm=llm,
-                                   id_column_name=args.id_col,
-                                   cache_target=":".join([cache_filepath, cache_table]))
-
-     # Setup output target
-     tgt_ext = src_ext if tgt_ext is None else tgt_ext
-     output_target = default_output_file_template(f".{tgt_ext}") if tgt_filepath is None else tgt_filepath
-
-     # Perform output writing process.
-     output_providers[tgt_ext](filepath=output_target,
-                               data_it=data_it,
-                               header=SQLite3Service.read_columns(target=cache_filepath, table=cache_table))
bulk_chain-0.25.2.dist-info/RECORD DELETED
@@ -1,20 +0,0 @@
- bulk_chain/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- bulk_chain/api.py,sha256=3q1t4A5wop_BRgYanFCCSQBiGu38P9ds0hTbuxNIUKQ,3590
- bulk_chain/demo.py,sha256=3mvgEu03EyDDFzXtpx2fxozLITOn9Lo7ati6H1y54S4,3191
- bulk_chain/infer.py,sha256=gq6G48XpOK56g5I_AU2kiQirQgcrZ353kfwjjRfQhSo,8069
- bulk_chain/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- bulk_chain/core/llm_base.py,sha256=fuWxfEOSRYvnoZMOcfnq1E2LIJKnrpsnxQ1z6SmY1nM,1839
- bulk_chain/core/service_args.py,sha256=lq4Veuh4QNu8mlCv8MT9S1rMxTn4FKalyp-3boYonVk,2136
- bulk_chain/core/service_batch.py,sha256=yQr6fbQd4ifQBGMhZMrQQeZpXtDchMKMGJi8XPG7thc,1430
- bulk_chain/core/service_data.py,sha256=ZjJDtd1jrQm9hRCXMqe4CT_qF2XDbWBE1lVibP7tAWo,942
- bulk_chain/core/service_dict.py,sha256=lAghLU-3V3xYGv5BTA327Qcw8UJYmgQRMFdggzlrUgo,383
- bulk_chain/core/service_json.py,sha256=6o1xM_8c9QEjH9Q3qEmJylU9nahfRXhUd5sFF2dGJwo,182
- bulk_chain/core/service_llm.py,sha256=0lFqX02-BHI9OOdC-7hZhpsb9QrhCbKE7In3jhKXq3I,3452
- bulk_chain/core/service_schema.py,sha256=KIP4n0Tz2h1i7SIMGhgAhoiCgUFXOT1rzMt38yACS2U,1154
- bulk_chain/core/utils.py,sha256=UV6Cefaw7yZiYblsCr-s9LsbcI83xe7eESBvha9A2Og,2784
- bulk_chain/core/utils_logger.py,sha256=BD-ADxaeeuHztaYjqtIY_cIzc5r2Svq9XwRtrgIEqyI,1636
- bulk_chain-0.25.2.dist-info/LICENSE,sha256=VF9SjNpwwSSFEY_eP_8A1ocDCrbwfjI1pZexXdCkOwo,1076
- bulk_chain-0.25.2.dist-info/METADATA,sha256=-N7-wOVXryBY1jkARSgWYZUAhdLYZlxkJ8qa8Vuj9no,6037
- bulk_chain-0.25.2.dist-info/WHEEL,sha256=pL8R0wFFS65tNSRnaOVrsw9EOkOqxLrlUPenUYnJKNo,91
- bulk_chain-0.25.2.dist-info/top_level.txt,sha256=Hxq_wyH-GDXKBaA63UfBIiMJO2eCHJG5EOrXDphpeB4,11
- bulk_chain-0.25.2.dist-info/RECORD,,