bulk-chain 0.25.3__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bulk_chain/api.py CHANGED
@@ -1,70 +1,91 @@
1
+ import asyncio
1
2
  import collections
3
+ import logging
2
4
  import os
3
5
  from itertools import chain
4
6
 
5
7
  from bulk_chain.core.llm_base import BaseLM
6
- from bulk_chain.core.service_batch import BatchIterator, BatchService
8
+ from bulk_chain.core.service_asyncio import AsyncioService
9
+ from bulk_chain.core.service_batch import BatchIterator
7
10
  from bulk_chain.core.service_data import DataService
8
11
  from bulk_chain.core.service_dict import DictionaryService
9
12
  from bulk_chain.core.service_json import JsonService
10
13
  from bulk_chain.core.service_schema import SchemaService
11
- from bulk_chain.core.utils import dynamic_init, find_by_prefix
14
+ from bulk_chain.core.utils import attempt_wrapper
12
15
 
13
16
 
14
17
  INFER_MODES = {
15
- "batch": lambda llm, batch, limit_prompt=None: llm.ask_core(
16
- DataService.limit_prompts(batch, limit=limit_prompt))
18
+ "single": lambda llm, batch, **kwargs: [llm.ask(prompt) for prompt in batch],
19
+ "single_stream": lambda llm, batch, **kwargs: [llm.ask_stream(prompt) for prompt in batch],
20
+ "batch": lambda llm, batch, **kwargs: llm.ask(batch),
21
+ "batch_async": lambda llm, batch, **kwargs: AsyncioService.run_tasks(
22
+ batch=batch, async_handler=llm.ask_async, event_loop=kwargs.get("event_loop")
23
+ ),
24
+ "batch_stream_async": lambda llm, batch, **kwargs: AsyncioService.run_tasks(
25
+ batch=batch, async_handler=llm.ask_stream_async, event_loop=kwargs.get("event_loop")
26
+ ),
17
27
  }
18
28
 
19
29
 
20
30
  CWD = os.getcwd()
21
31
 
22
32
 
23
- def _handle_entry(entry, entry_info=None, **kwargs):
33
+ def _iter_batch_prompts(c, batch_content_it, **kwargs):
34
+ for ind_in_batch, entry in enumerate(batch_content_it):
35
+ content = DataService.get_prompt_text(
36
+ prompt=entry[c]["prompt"],
37
+ data_dict=entry,
38
+ handle_missed_func=kwargs["handle_missed_value_func"])
39
+ yield ind_in_batch, content
24
40
 
25
- if isinstance(entry, str):
26
- kwargs.get("callback_str_func", lambda *_: None)(entry, entry_info)
27
- return entry
28
- elif isinstance(entry, collections.abc.Iterable):
29
- chunks = []
30
- h = kwargs.get("callback_stream_func", lambda *_: None)
31
41
 
32
- h(None, entry_info | {"action": "start"})
42
+ def __handle_agen_to_gen(handle, batch, event_loop):
43
+ """ This handler provides conversion of the async generator to generator (sync).
44
+ """
33
45
 
34
- for chunk in map(lambda item: str(item), entry):
35
- chunks.append(chunk)
36
- h(chunk, entry_info)
46
+ def __wrap_with_index(async_gens):
47
+ async def wrapper(index, agen):
48
+ async for item in agen:
49
+ yield index, item
50
+ return [wrapper(i, agen) for i, agen in enumerate(async_gens)]
37
51
 
38
- h(None, entry_info | {"action": "end"})
52
+ agen_list = handle(batch, event_loop=event_loop)
39
53
 
40
- return "".join(chunks)
54
+ it = AsyncioService.async_gen_to_iter(
55
+ gen=AsyncioService.merge_generators(*__wrap_with_index(agen_list)),
56
+ loop=event_loop)
41
57
 
42
- raise Exception(f"Non supported type `{type(entry)}` for handling output from batch")
58
+ for ind_in_batch, chunk in it:
59
+ yield ind_in_batch, str(chunk)
43
60
 
44
61
 
45
- def _update_batch_content(c, batch, schema, **kwargs):
46
- assert (isinstance(batch, list))
47
- assert (isinstance(c, str))
48
-
49
- if c in schema.p2r:
50
- for batch_item in batch:
51
- batch_item[c] = DataService.get_prompt_text(
52
- prompt=batch_item[c]["prompt"],
53
- data_dict=batch_item,
54
- handle_missed_func=kwargs["handle_missed_value_func"])
55
- if c in schema.r2p:
56
- p_column = schema.r2p[c]
57
- # This instruction takes a lot of time in a non-batching mode.
58
- BatchService.handle_param_as_batch(
59
- batch=batch,
60
- src_param=p_column,
61
- tgt_param=c,
62
- handle_batch_func=lambda b: kwargs["handle_batch_func"](b),
63
- handle_entry_func=lambda entry, info: _handle_entry(entry=entry, entry_info=info, **kwargs)
64
- )
65
-
66
-
67
- def _infer_batch(batch, schema, cols=None, **kwargs):
62
+ def __handle_gen(handle, batch, event_loop):
63
+ """ This handler deals with the iteration of each individual element of the batch.
64
+ """
65
+
66
+ def _iter_entry_content(entry):
67
+ if isinstance(entry, str):
68
+ yield entry
69
+ elif isinstance(entry, collections.abc.Iterable):
70
+ for chunk in map(lambda item: str(item), entry):
71
+ yield chunk
72
+ else:
73
+ raise Exception(f"Non supported type `{type(entry)}` for handling output from batch")
74
+
75
+ for ind_in_batch, entry in enumerate(handle(batch, event_loop=event_loop)):
76
+ for chunk in _iter_entry_content(entry=entry):
77
+ yield ind_in_batch, chunk
78
+
79
+
80
+ def _iter_chunks(p_column, batch_content_it, **kwargs):
81
+ handler = __handle_agen_to_gen if kwargs["infer_mode"] == "batch_stream_async" else __handle_gen
82
+ p_batch = [item[p_column] for item in batch_content_it]
83
+ it = handler(handle=kwargs["handle_batch_func"], batch=p_batch, event_loop=kwargs["event_loop"])
84
+ for ind_in_batch, chunk in it:
85
+ yield ind_in_batch, chunk
86
+
87
+
88
+ def _infer_batch(batch, batch_ind, schema, return_mode, cols=None, **kwargs):
68
89
  assert (isinstance(batch, list))
69
90
 
70
91
  if len(batch) == 0:
@@ -75,18 +96,54 @@ def _infer_batch(batch, schema, cols=None, **kwargs):
75
96
  cols = list(first_item.keys()) if cols is None else cols
76
97
 
77
98
  for c in cols:
78
- _update_batch_content(c=c, batch=batch, schema=schema, **kwargs)
79
-
80
- return batch
81
99
 
82
-
83
- def iter_content(input_dicts_it, llm, schema, batch_size=1, return_batch=True, limit_prompt=None, **kwargs):
100
+ # Handling prompt column.
101
+ if c in schema.p2r:
102
+ content_it = _iter_batch_prompts(c=c, batch_content_it=iter(batch), **kwargs)
103
+ for ind_in_batch, prompt in content_it:
104
+ batch[ind_in_batch][c] = prompt
105
+
106
+ # Handling column for inference.
107
+ if c in schema.r2p:
108
+ content_it = _iter_chunks(p_column=schema.r2p[c], batch_content_it=iter(batch), **kwargs)
109
+ # Register values.
110
+ for item in batch:
111
+ item[c] = []
112
+ for ind_in_batch, chunk in content_it:
113
+ # Append batch.
114
+ batch[ind_in_batch][c].append(chunk)
115
+ # Returning (optional).
116
+ if return_mode == "chunk":
117
+ global_ind = batch_ind * len(batch) + ind_in_batch
118
+ yield [global_ind, c, chunk]
119
+
120
+ # Convert content to string.
121
+ for item in batch:
122
+ item[c] = "".join(item[c])
123
+
124
+ if return_mode == "record":
125
+ for record in batch:
126
+ yield record
127
+
128
+ if return_mode == "batch":
129
+ yield batch
130
+
131
+
132
+ def iter_content(input_dicts_it, llm, schema, batch_size=1, limit_prompt=None,
133
+ infer_mode="batch", return_mode="batch", attempts=1, event_loop=None,
134
+ **kwargs):
84
135
  """ This method represent Python API aimed at application of `llm` towards
85
136
  iterator of input_dicts, producing content according to
86
137
  the given `schema`.
87
138
  """
139
+ assert (infer_mode in INFER_MODES.keys())
140
+ assert (return_mode in ["batch", "chunk", "record"])
88
141
  assert (isinstance(llm, BaseLM))
89
142
 
143
+ # Setup event loop.
144
+ event_loop = asyncio.get_event_loop_policy().get_event_loop() \
145
+ if event_loop is None else event_loop
146
+
90
147
  # Quick initialization of the schema.
91
148
  if isinstance(schema, str):
92
149
  schema = JsonService.read(schema)
@@ -94,35 +151,36 @@ def iter_content(input_dicts_it, llm, schema, batch_size=1, return_batch=True, l
94
151
  schema = SchemaService(json_data=schema)
95
152
 
96
153
  prompts_it = map(
97
- lambda data: DictionaryService.custom_update(src_dict=data, other_dict=schema.cot_args),
154
+ lambda data: DictionaryService.custom_update(src_dict=dict(data), other_dict=schema.cot_args),
98
155
  input_dicts_it
99
156
  )
100
157
 
158
+ handle_batch_func = lambda batch, **handle_kwargs: INFER_MODES[infer_mode](
159
+ llm,
160
+ DataService.limit_prompts(batch, limit=limit_prompt),
161
+ **handle_kwargs
162
+ )
163
+
164
+ # Optional wrapping into attempts.
165
+ if attempts > 1:
166
+ # Optional setup of the logger.
167
+ logger = logging.getLogger(__name__)
168
+ logging.basicConfig(level=logging.INFO)
169
+
170
+ attempt_dec = attempt_wrapper(attempts=attempts,
171
+ delay_sec=kwargs.get("attempt_delay_sec", 1),
172
+ logger=logger)
173
+ handle_batch_func = attempt_dec(handle_batch_func)
174
+
101
175
  content_it = (_infer_batch(batch=batch,
102
- handle_batch_func=lambda batch: INFER_MODES["batch"](llm, batch, limit_prompt),
176
+ batch_ind=batch_ind,
177
+ infer_mode=infer_mode,
178
+ handle_batch_func=handle_batch_func,
179
+ handle_missed_value_func=lambda *_: None,
180
+ return_mode=return_mode,
103
181
  schema=schema,
182
+ event_loop=event_loop,
104
183
  **kwargs)
105
- for batch in BatchIterator(prompts_it, batch_size=batch_size))
106
-
107
- yield from content_it if return_batch else chain.from_iterable(content_it)
184
+ for batch_ind, batch in enumerate(BatchIterator(prompts_it, batch_size=batch_size)))
108
185
 
109
-
110
- def init_llm(adapter, **model_kwargs):
111
- """ This method perform dynamic initialization of LLM from third-party resource.
112
- """
113
- assert (isinstance(adapter, str))
114
-
115
- # List of the Supported models and their API wrappers.
116
- models_preset = {
117
- "dynamic": lambda: dynamic_init(class_dir=CWD, class_filepath=llm_model_name,
118
- class_name=llm_model_params)(**model_kwargs)
119
- }
120
-
121
- # Initialize LLM model.
122
- params = adapter.split(':')
123
- llm_model_type = params[0]
124
- llm_model_name = params[1] if len(params) > 1 else params[-1]
125
- llm_model_params = ':'.join(params[2:]) if len(params) > 2 else None
126
- llm = find_by_prefix(d=models_preset, key=llm_model_type)()
127
-
128
- return llm, llm_model_name
186
+ yield from chain.from_iterable(content_it)
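The reworked `api.py` above drops `init_llm` and turns `iter_content` into a single generator: `INFER_MODES` selects how each batch reaches the model (`single`, `single_stream`, `batch`, `batch_async`, `batch_stream_async`), `attempts` optionally wraps the call in `attempt_wrapper`, and `return_mode` controls whether the generator yields whole batches, single records, or `[global_index, column, chunk]` triples. A minimal sketch of driving the new API with a toy synchronous model; `EchoLM` and the schema path are illustrative placeholders, not part of the package:

```python
from bulk_chain.api import iter_content
from bulk_chain.core.llm_base import BaseLM


class EchoLM(BaseLM):
    """Toy model for illustration only: echoes every prompt back."""

    def ask(self, content):
        # With infer_mode="batch" the whole batch arrives as one list,
        # so one response string is returned per prompt.
        return [f"echo: {prompt}" for prompt in content]


content_it = iter_content(
    input_dicts_it=iter([{"text": "cats"}, {"text": "dogs"}]),
    llm=EchoLM(),
    schema="YOUR_SCHEMA.json",   # placeholder path to a CoT schema file
    infer_mode="batch",          # one llm.ask(batch) call per batch
    return_mode="record",        # yield one completed dict at a time
    batch_size=2,
)

for record in content_it:
    print(record)
```

With `return_mode="chunk"` the same loop would instead receive `[global_index, column, chunk]` triples as they are produced, which is the form streaming consumers are expected to handle.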
bulk_chain/core/llm_base.py CHANGED
@@ -1,52 +1,24 @@
1
- import logging
2
- import time
3
-
4
- from bulk_chain.core.utils import format_model_name
5
-
6
-
7
1
  class BaseLM(object):
8
2
 
9
- def __init__(self, name=None, attempts=None, delay_sec=1, enable_log=True,
10
- support_batching=False, **kwargs):
11
-
12
- self.__name = name
13
- self.__attempts = 1 if attempts is None else attempts
14
- self.__delay_sec = delay_sec
15
- self.__support_batching = support_batching
16
-
17
- if enable_log:
18
- self.__logger = logging.getLogger(__name__)
19
- logging.basicConfig(level=logging.INFO)
20
-
21
- def ask_core(self, batch):
22
-
23
- for i in range(self.__attempts):
24
- try:
25
- if self.__support_batching:
26
- # Launch in batch mode.
27
- content = batch
28
- else:
29
- # Launch in non-batch mode.
30
- assert len(batch) == 1, "The LM does not support batching," \
31
- f" while size of the content is {len(batch)} which is not equal 1. " \
32
- f"Please enable batch-supporting or set required inference settings."
33
- content = batch[0]
3
+ def __init__(self, **kwargs):
4
+ pass
34
5
 
35
- response = self.ask(content)
36
-
37
- # Wrapping into batch the response in the case of non-batching mode.
38
- return response if self.__support_batching else [response]
39
-
40
- except Exception as e:
41
- if self.__logger is not None:
42
- self.__logger.info("Unable to infer the result. Try {} out of {}.".format(i, self.__attempts))
43
- self.__logger.info(e)
44
- time.sleep(self.__delay_sec)
6
+ def ask(self, content):
7
+ """ Assumes to return str.
8
+ """
9
+ raise NotImplemented()
45
10
 
46
- raise Exception("Can't infer")
11
+ def ask_stream(self, content):
12
+ """ Assumes to return generator.
13
+ """
14
+ raise NotImplemented()
47
15
 
48
- def ask(self, content):
16
+ async def ask_async(self, prompt):
17
+ """ Assumes to return co-routine.
18
+ """
49
19
  raise NotImplemented()
50
20
 
51
- def name(self):
52
- return format_model_name(self.__name)
21
+ async def ask_stream_async(self, batch):
22
+ """ Assumes to return AsyncGenerator.
23
+ """
24
+ raise NotImplemented()
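The slimmed-down `BaseLM` drops the built-in retrying, logging and batching switches of 0.25.3 (retries now live in `attempt_wrapper` in `core/utils.py`, batching dispatch in `INFER_MODES`) and leaves four overridable hooks, one per inference mode. Below is a hedged sketch of a provider that implements all four on top of a single blocking call; `MyProvider` and its fake response are illustrative. Note that `ask_stream_async` is consumed through `asyncio.gather`, so it is written here as a coroutine that returns an async generator:

```python
import asyncio

from bulk_chain.core.llm_base import BaseLM


class MyProvider(BaseLM):
    """Illustrative adapter built on a single blocking request."""

    def __init__(self, api_token=None, **kwargs):
        super().__init__(**kwargs)
        self.api_token = api_token

    def ask(self, content):
        # Stand-in for a real blocking API call.
        return f"response for: {content}"

    def ask_stream(self, content):
        # Plain generator: yield the answer in small chunks.
        for token in self.ask(content).split():
            yield token + " "

    async def ask_async(self, prompt):
        # Coroutine: run the blocking call in a worker thread.
        return await asyncio.to_thread(self.ask, prompt)

    async def ask_stream_async(self, prompt):
        # Coroutine returning an async generator, matching how the
        # "batch_stream_async" mode first gathers and then merges them.
        async def chunks():
            for token in self.ask(prompt).split():
                yield token + " "
        return chunks()
```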
bulk_chain/core/service_asyncio.py ADDED
@@ -0,0 +1,65 @@
1
+ import asyncio
2
+ from typing import AsyncGenerator, Any
3
+
4
+
5
+ class AsyncioService:
6
+
7
+ @staticmethod
8
+ async def _run_tasks_async(batch, async_handler):
9
+ tasks = [async_handler(prompt) for prompt in batch]
10
+ return await asyncio.gather(*tasks)
11
+
12
+ @staticmethod
13
+ async def _run_generator(gen, output_queue, idx):
14
+ try:
15
+ async for item in gen:
16
+ await output_queue.put((idx, item))
17
+ finally:
18
+ await output_queue.put((idx, StopAsyncIteration))
19
+
20
+
21
+ @staticmethod
22
+ def run_tasks(event_loop, **tasks_kwargs):
23
+ return event_loop.run_until_complete(AsyncioService._run_tasks_async(**tasks_kwargs))
24
+
25
+ @staticmethod
26
+ async def merge_generators(*gens: AsyncGenerator[Any, None]) -> AsyncGenerator[Any, None]:
27
+
28
+ output_queue = asyncio.Queue()
29
+ tasks = [
30
+ asyncio.create_task(AsyncioService._run_generator(gen, output_queue, idx))
31
+ for idx, gen in enumerate(gens)
32
+ ]
33
+
34
+ finished = set()
35
+ while len(finished) < len(tasks):
36
+ idx, item = await output_queue.get()
37
+ if item is StopAsyncIteration:
38
+ finished.add(idx)
39
+ else:
40
+ yield item
41
+
42
+ for task in tasks:
43
+ task.cancel()
44
+
45
+ @staticmethod
46
+ def async_gen_to_iter(gen, loop=None):
47
+ """ This approach is limited. Could be considered as legacy.
48
+ https://stackoverflow.com/questions/71580727/translating-async-generator-into-sync-one/78573267#78573267
49
+ """
50
+
51
+ loop_created = False
52
+ if loop is None:
53
+ loop_created = True
54
+ loop = asyncio.new_event_loop()
55
+
56
+ asyncio.set_event_loop(loop)
57
+ try:
58
+ while True:
59
+ try:
60
+ yield loop.run_until_complete(gen.__anext__())
61
+ except StopAsyncIteration:
62
+ break
63
+ finally:
64
+ if loop_created:
65
+ loop.close()
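`service_asyncio.py` is new in 1.1.0 and carries the async plumbing behind the `batch_async` and `batch_stream_async` modes: `run_tasks` gathers one coroutine per prompt on a given loop, `merge_generators` multiplexes several async generators through a queue, and `async_gen_to_iter` drains an async generator from synchronous code. A small self-contained sketch of the two generator helpers, independent of any LLM (the `count` generator is a toy stand-in):

```python
import asyncio

from bulk_chain.core.service_asyncio import AsyncioService


async def count(tag, n):
    # Toy async generator standing in for a streamed LLM answer.
    for i in range(n):
        await asyncio.sleep(0.01)
        yield f"{tag}-{i}"


# Interleave two async streams and consume them from regular sync code.
merged = AsyncioService.merge_generators(count("a", 3), count("b", 2))
for item in AsyncioService.async_gen_to_iter(gen=merged):
    print(item)
```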
bulk_chain/core/service_batch.py CHANGED
@@ -1,27 +1,8 @@
1
- class BatchService(object):
2
-
3
- @staticmethod
4
- def handle_param_as_batch(batch, src_param, tgt_param, handle_batch_func, handle_entry_func):
5
- assert (isinstance(batch, list))
6
- assert (isinstance(src_param, str))
7
- assert (callable(handle_batch_func))
8
-
9
- _batch = [item[src_param] for item in batch]
10
-
11
- # Do handling for the batch.
12
- _handled_batch = handle_batch_func(_batch)
13
- assert (isinstance(_handled_batch, list))
14
-
15
- # Apply changes.
16
- for i, item in enumerate(batch):
17
- item[tgt_param] = handle_entry_func(entry=_handled_batch[i], info={"ind": i, "param": tgt_param})
18
-
19
-
20
1
  class BatchIterator:
21
2
 
22
3
  def __init__(self, data_iter, batch_size, end_value=None, filter_func=None):
23
- assert(isinstance(batch_size, int) and batch_size > 0)
24
- assert(callable(end_value) or end_value is None)
4
+ assert (isinstance(batch_size, int) and batch_size > 0)
5
+ assert (callable(end_value) or end_value is None)
25
6
  self.__data_iter = data_iter
26
7
  self.__index = 0
27
8
  self.__batch_size = batch_size
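`BatchService.handle_param_as_batch` is removed (its role is now covered by `_iter_chunks` in `api.py`), so `service_batch.py` retains only `BatchIterator`, which groups a flat iterator into lists of at most `batch_size` items for `iter_content`. A quick illustration, with the output shape inferred from how `iter_content` consumes it:

```python
from bulk_chain.core.service_batch import BatchIterator

rows = ({"id": i} for i in range(5))

for batch in BatchIterator(rows, batch_size=2):
    # Expected: lists of up to 2 dicts each.
    print(batch)
```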
bulk_chain/core/utils.py CHANGED
@@ -1,6 +1,7 @@
1
1
  import importlib
2
2
  import logging
3
3
  import sys
4
+ import time
4
5
  from collections import Counter
5
6
  from os.path import dirname, join, basename
6
7
 
@@ -48,28 +49,6 @@ def iter_params(text):
48
49
  beg = pe+1
49
50
 
50
51
 
51
- def format_model_name(name):
52
- return name.replace("/", "_")
53
-
54
-
55
- def parse_filepath(filepath, default_filepath=None, default_ext=None):
56
- """ This is an auxiliary function for handling sources and targets from cmd string.
57
- """
58
- if filepath is None:
59
- return default_filepath, default_ext, None
60
- info = filepath.split(":")
61
- filepath = info[0]
62
- meta = info[1] if len(info) > 1 else None
63
- ext = filepath.split('.')[-1] if default_ext is None else default_ext
64
- return filepath, ext, meta
65
-
66
-
67
- def handle_table_name(name):
68
- return name.\
69
- replace('-', '_').\
70
- replace('.', "_")
71
-
72
-
73
52
  def auto_import(name, is_class=False):
74
53
  """ Import from the external python packages.
75
54
  """
@@ -82,10 +61,10 @@ def auto_import(name, is_class=False):
82
61
  return m() if is_class else m
83
62
 
84
63
 
85
- def dynamic_init(class_dir, class_filepath, class_name=None):
64
+ def dynamic_init(class_filepath, class_name=None):
86
65
 
87
66
  # Registering path.
88
- target = join(class_dir, dirname(class_filepath))
67
+ target = join(dirname(class_filepath))
89
68
  logger.info(f"Adding sys path for `{target}`")
90
69
  sys.path.insert(1, target)
91
70
  class_path_list = class_filepath.split('/')
@@ -111,3 +90,21 @@ def optional_limit_iter(it_data, limit=None):
111
90
  if limit is not None and counter["returned"] > limit:
112
91
  break
113
92
  yield data
93
+
94
+
95
+ def attempt_wrapper(attempts, delay_sec=1, logger=None):
96
+ def decorator(func):
97
+ def wrapper(*args, **kwargs):
98
+ for i in range(attempts):
99
+ try:
100
+ # Do action.
101
+ return func(*args, **kwargs)
102
+ except Exception as e:
103
+ if logger is not None:
104
+ logger.info(f"Unable to infer the result. Try {i} out of {attempts}.")
105
+ logger.info(e)
106
+ if delay_sec is not None:
107
+ time.sleep(delay_sec)
108
+ raise Exception(f"Failed after {attempts} attempts")
109
+ return wrapper
110
+ return decorator
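`attempt_wrapper` replaces the retry loop that previously lived inside `BaseLM.ask_core`: it is a plain decorator factory, which is how `iter_content` wraps `handle_batch_func` when `attempts > 1`, and it can equally wrap any flaky callable directly. A small illustrative usage (the `flaky_call` function is hypothetical):

```python
import logging
import random

from bulk_chain.core.utils import attempt_wrapper

logging.basicConfig(level=logging.INFO)


@attempt_wrapper(attempts=3, delay_sec=1, logger=logging.getLogger(__name__))
def flaky_call():
    # Stand-in for a remote LLM request that sometimes fails.
    if random.random() < 0.5:
        raise ConnectionError("temporary failure")
    return "ok"


print(flaky_call())  # retried up to 3 times; raises "Failed after 3 attempts" if none succeed
```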
bulk_chain-1.1.0.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: bulk_chain
3
- Version: 0.25.3
3
+ Version: 1.1.0
4
4
  Summary: A lightweight, no-strings-attached Chain-of-Thought framework for your LLM, ensuring reliable results for bulk input requests.
5
5
  Home-page: https://github.com/nicolay-r/bulk-chain
6
6
  Author: Nicolay Rusnachenko
@@ -15,10 +15,8 @@ Classifier: Topic :: Text Processing :: Linguistic
15
15
  Requires-Python: >=3.6
16
16
  Description-Content-Type: text/markdown
17
17
  License-File: LICENSE
18
- Requires-Dist: tqdm
19
- Requires-Dist: source-iter ==0.24.3
20
18
 
21
- # bulk-chain 0.25.3
19
+ # bulk-chain 1.1.0
22
20
  ![](https://img.shields.io/badge/Python-3.9-brightgreen.svg)
23
21
  [![](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/nicolay-r/bulk-chain/blob/master/bulk_chain_tutorial.ipynb)
24
22
  [![twitter](https://img.shields.io/twitter/url/https/shields.io.svg?style=social)](https://x.com/nicolayr_/status/1847969224636961033)
@@ -31,7 +29,7 @@ Requires-Dist: source-iter ==0.24.3
31
29
  <p align="center">
32
30
  <a href="https://github.com/nicolay-r/nlp-thirdgate?tab=readme-ov-file#llm"><b>Third-party providers hosting</b>↗️</a>
33
31
  <br>
34
- <a href="https://github.com/nicolay-r/bulk-chain/blob/master/README.md#demo-mode">👉<b>demo</b>👈</a>
32
+ <a href="https://github.com/nicolay-r/bulk-chain-shell">👉<b>demo</b>👈</a>
35
33
  </p>
36
34
 
37
35
  A no-strings-attached **framework** for your LLM that allows applying a Chain-of-Thought-like [prompt `schema`](#chain-of-thought-schema) to massive textual collections using custom **[third-party providers ↗️](https://github.com/nicolay-r/nlp-thirdgate?tab=readme-ov-file#llm)**.
@@ -39,11 +37,7 @@ A no-strings-attached **framework** for your LLM that allows applying Chain-of-
39
37
  ### Main Features
40
38
  * ✅ **No-strings**: you're free from LLM dependencies and flexible in `venv` customization.
41
39
  * ✅ **Supports schema descriptions** for the Chain-of-Thought concept.
42
- * ✅ **Provides iterator over infinite amount of input contexts** served in `CSV`/`JSONL`.
43
-
44
- ### Extra Features
45
- * ✅ **Progress caching [for remote LLMs]**: withstanding exception during LLM calls by using `sqlite3` engine for caching LLM answers;
46
-
40
+ * ✅ **Provides an iterator over an unlimited number of input contexts**
47
41
 
48
42
  # Installation
49
43
 
@@ -83,60 +77,37 @@ Below, is an example on how to declare your own schema:
83
77
 
84
78
  # Usage
85
79
 
86
- Preliminary steps:
87
-
88
- 1. Define your [schema](#chain-of-thought-schema) ([Example for Sentiment Analysis](/ext/schema/thor_cot_schema.json)))
89
- 2. Wrap or pick **LLM model** from the [<b>Third-party providers hosting</b>↗️](https://github.com/nicolay-r/nlp-thirdgate?tab=readme-ov-file#llm).
90
-
91
- ## Shell
92
-
93
- ### Demo Mode
94
-
95
- **demo mode** to interact with LLM via command line with LLM output streaming support.
96
- The video below illustrates an example of application for sentiment analysis on author opinion extraction towards mentioned object in text.
97
-
98
- Quck start with launching demo:
99
- 1. ⬇️ Download [replicate](https://replicate.com/) provider for `bulk-chain`:
100
- 2. 📜 Setup your reasoning `thor_cot_schema.json` according to the [following example ↗️](test/schema/thor_cot_schema.json)
101
- 3. 🚀 Launch `demo.py` as follows:
102
- ```bash
103
- python3 -m bulk_chain.demo \
104
- --schema "test/schema/thor_cot_schema.json" \
105
- --adapter "dynamic:replicate_104.py:Replicate" \
106
- %%m \
107
- --model_name "meta/meta-llama-3-70b-instruct" \
108
- --api_token "<REPLICATE-API-TOKEN>" \
109
- --stream
110
- ```
111
-
112
- 📺 This video showcase application of the [↗️ Sentiment Analysis Schema](https://github.com/nicolay-r/bulk-chain/blob/master/test/schema/thor_cot_schema.json) towards [LLaMA-3-70B-Instruct](https://replicate.com/meta/meta-llama-3-70b-instruct) hosted by Replicate for reasoning over submitted texts
113
- ![sa-bulk-chain-cot-final](https://github.com/user-attachments/assets/0cc8fdcb-6ddb-44a3-8f05-d76250ae6423)
80
+ ## 🤖 Prepare
114
81
 
82
+ 1. [schema](#chain-of-thought-schema)
83
+ * [Example for Sentiment Analysis](test/schema/thor_cot_schema.json)
84
+ 2. **LLM model** from the [<b>Third-party providers hosting</b>↗️](https://github.com/nicolay-r/nlp-thirdgate?tab=readme-ov-file#llm).
85
+ 3. Data (iter of dictionaries)
115
86
 
116
- ### Inference Mode
87
+ ## 🚀 Launch
117
88
 
118
- > **NOTE:** You have to install `source-iter` and `tqdm` packages that actual [dependencies](dependencies.txt) of this project
89
+ > **API**: For more details see the [**related Wiki page**](https://github.com/nicolay-r/bulk-chain/wiki)
119
90
 
120
- 1. ⬇️ Download [replicate](https://replicate.com/) provider for `bulk-chain`:
121
- ```bash
122
- wget https://raw.githubusercontent.com/nicolay-r/nlp-thirdgate/refs/heads/master/llm/replicate_104.py
123
- ```
124
- 2. 📜 Setup your reasoning `schema.json` according to the [following example ↗️](test/schema/default.json)
125
- 3. 🚀 Launch inference using `DeepSeek-R1`:
126
- ```bash
127
- python3 -m bulk_chain.infer \
128
- --src "<PATH-TO-YOUR-CSV-or-JSONL>" \
129
- --schema "test/schema/default.json" \
130
- --adapter "replicate_104.py:Replicate" \
131
- %%m \
132
- --model_name "deepseek-ai/deepseek-r1" \
133
- --api_token "<REPLICATE-API-TOKEN>"
91
+ ```python
92
+ from bulk_chain.core.utils import dynamic_init
93
+ from bulk_chain.api import iter_content
94
+
95
+ content_it = iter_content(
96
+ # 1. Your schema.
97
+ schema="YOUR_SCHEMA.json",
98
+ # 2. Your third-party model implementation.
99
+ llm=dynamic_init(class_filepath="replicate_104.py", class_name="Replicate")(api_token="<API-KEY>"),
100
+ # 3. Customize your inference and result providing modes:
101
+ infer_mode="batch_async",
102
+ return_mode="batch",
103
+ # 4. Your iterator of dictionaries
104
+ input_dicts_it=YOUR_DATA_IT,
105
+ )
106
+
107
+ for content in content_it:
108
+ # Handle your LLM responses here ...
134
109
  ```
135
110
 
136
- ## API
137
-
138
- Please take a look at the [**related Wiki page**](https://github.com/nicolay-r/bulk-chain/wiki)
139
-
140
111
 
141
112
  # Embed your LLM
142
113
 
bulk_chain-1.1.0.dist-info/RECORD ADDED
@@ -0,0 +1,16 @@
1
+ bulk_chain/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
+ bulk_chain/api.py,sha256=gPGjaHYIn2Ewn6yXIXER-CM5SgXQ3ZJH-SdRyaPDOo0,6890
3
+ bulk_chain/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
+ bulk_chain/core/llm_base.py,sha256=aa73TGW03yLXMHY4b_1NgquRvP0CzH8IWZkcFPABFUg,557
5
+ bulk_chain/core/service_asyncio.py,sha256=S-D4K3LBa3noKTm0tXazluYVI8cBgN1IB6v6MFoMyNQ,1972
6
+ bulk_chain/core/service_batch.py,sha256=lWmjO0aU6h2rmfx_kGmNqt0Rdeaf2a4Dn5VyfKFkfDs,1033
7
+ bulk_chain/core/service_data.py,sha256=OWWHHnr_plwxYTxLuvMrhEc1PbSx-XC3rbFzV0hy3vk,1107
8
+ bulk_chain/core/service_dict.py,sha256=lAghLU-3V3xYGv5BTA327Qcw8UJYmgQRMFdggzlrUgo,383
9
+ bulk_chain/core/service_json.py,sha256=6o1xM_8c9QEjH9Q3qEmJylU9nahfRXhUd5sFF2dGJwo,182
10
+ bulk_chain/core/service_schema.py,sha256=KIP4n0Tz2h1i7SIMGhgAhoiCgUFXOT1rzMt38yACS2U,1154
11
+ bulk_chain/core/utils.py,sha256=tp1FJQBmJt-3QmG7B0hyJNTFyg_8BwTTdl8xTxSgNDk,3140
12
+ bulk_chain-1.1.0.dist-info/LICENSE,sha256=VF9SjNpwwSSFEY_eP_8A1ocDCrbwfjI1pZexXdCkOwo,1076
13
+ bulk_chain-1.1.0.dist-info/METADATA,sha256=EheCGDisKF0TwmzJfnDxW-rgsDVPNpCYGOvuaDn91tw,4428
14
+ bulk_chain-1.1.0.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
15
+ bulk_chain-1.1.0.dist-info/top_level.txt,sha256=Hxq_wyH-GDXKBaA63UfBIiMJO2eCHJG5EOrXDphpeB4,11
16
+ bulk_chain-1.1.0.dist-info/RECORD,,
bulk_chain-1.1.0.dist-info/WHEEL CHANGED
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (74.1.3)
2
+ Generator: setuptools (75.1.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
bulk_chain/core/provider_sqlite.py DELETED
@@ -1,127 +0,0 @@
1
- import sqlite3
2
-
3
-
4
- class SQLite3Service(object):
5
-
6
- @staticmethod
7
- def __create_table(table_name, columns, id_column_name,
8
- id_column_type, cur, sqlite3_column_types=None):
9
-
10
- # Setting up default column types.
11
- if sqlite3_column_types is None:
12
- types_count = len(columns) if id_column_name in columns else len(columns) - 1
13
- sqlite3_column_types = ["TEXT"] * types_count
14
-
15
- # Provide the ID column.
16
- sqlite3_column_types = [id_column_type] + sqlite3_column_types
17
-
18
- # Compose the whole columns list.
19
- content = ", ".join([f"[{item[0]}] {item[1]}" for item in zip(columns, sqlite3_column_types)])
20
- cur.execute(f"CREATE TABLE IF NOT EXISTS {table_name}({content})")
21
- cur.execute(f"CREATE INDEX IF NOT EXISTS [{id_column_name}] ON {table_name}([{id_column_name}])")
22
-
23
- @staticmethod
24
- def __it_row_lists(cursor):
25
- for row in cursor:
26
- yield row
27
-
28
- @staticmethod
29
- def create_table_if_not_exist(**kwargs):
30
- return SQLite3Service.__create_table(**kwargs)
31
-
32
- @staticmethod
33
- def entry_exist(table_name, target, id_column_name, id_value, **connect_kwargs) -> bool:
34
- with sqlite3.connect(target, **connect_kwargs) as con:
35
- cursor = con.cursor()
36
-
37
- # Check table existance.
38
- query = "SELECT name FROM sqlite_master WHERE type='table' AND name=?"
39
- cursor.execute(query, (table_name,))
40
- if cursor.fetchone() is None:
41
- return False
42
-
43
- # Check element.
44
- r = cursor.execute(f"SELECT EXISTS(SELECT 1 FROM {table_name} WHERE [{id_column_name}]='{id_value}');")
45
- ans = r.fetchone()[0]
46
- return ans == 1
47
-
48
- @staticmethod
49
- def write(data_it, target, table_name, columns=None, id_column_name="id", data2col_func=None,
50
- id_column_type="INTEGER", sqlite3_column_types=None, it_type='dict',
51
- create_table_if_not_exist=True, skip_existed=True, **connect_kwargs):
52
-
53
- need_set_column_id = True
54
- need_initialize_columns = columns is None
55
-
56
- # Setup default columns.
57
- columns = [] if columns is None else columns
58
-
59
- with sqlite3.connect(target, **connect_kwargs) as con:
60
- cur = con.cursor()
61
-
62
- for content in data_it:
63
-
64
- if it_type == 'dict':
65
- # Extracting columns from data.
66
- data = content
67
- uid = data[id_column_name]
68
- row_columns = list(data.keys())
69
- row_params_func = lambda: [data2col_func(c, data) if data2col_func is not None else data[c]
70
- for c in row_columns]
71
- # Append columns if needed.
72
- if need_initialize_columns:
73
- columns = list(row_columns)
74
- elif it_type is None:
75
- # Setup row columns.
76
- uid, data = content
77
- row_columns = columns
78
- row_params_func = lambda: [uid] + data
79
- else:
80
- raise Exception(f"it_type {it_type} does not supported!")
81
-
82
- if need_set_column_id:
83
- # Register ID column.
84
- if id_column_name not in columns:
85
- columns.append(id_column_name)
86
- # Place ID column first.
87
- columns.insert(0, columns.pop(columns.index(id_column_name)))
88
- need_set_column_id = False
89
-
90
- if create_table_if_not_exist:
91
- SQLite3Service.__create_table(
92
- columns=columns, table_name=table_name, cur=cur,
93
- id_column_name=id_column_name, id_column_type=id_column_type,
94
- sqlite3_column_types=sqlite3_column_types)
95
-
96
- # Check that each rows satisfies criteria of the first row.
97
- [Exception(f"{column} is expected to be in row!") for column in row_columns if column not in columns]
98
-
99
- if skip_existed:
100
- r = cur.execute(f"SELECT EXISTS(SELECT 1 FROM {table_name} WHERE [{id_column_name}]='{uid}');")
101
- ans = r.fetchone()[0]
102
-
103
- if ans == 1:
104
- continue
105
-
106
- params = ", ".join(tuple(['?'] * (len(columns))))
107
- row_columns_str = ", ".join([f"[{col}]" for col in row_columns])
108
- content_list = row_params_func()
109
- cur.execute(f"INSERT INTO {table_name}({row_columns_str}) VALUES ({params})", content_list)
110
- con.commit()
111
-
112
- cur.close()
113
-
114
- @staticmethod
115
- def read(src, table="content", **connect_kwargs):
116
- with sqlite3.connect(src, **connect_kwargs) as conn:
117
- cursor = conn.cursor()
118
- cursor.execute(f"SELECT * FROM {table}")
119
- for record_list in SQLite3Service.__it_row_lists(cursor):
120
- yield record_list
121
-
122
- @staticmethod
123
- def read_columns(target, table="content", **connect_kwargs):
124
- with sqlite3.connect(target, **connect_kwargs) as conn:
125
- cursor = conn.cursor()
126
- cursor.execute(f"PRAGMA table_info({table})")
127
- return [row[1] for row in cursor.fetchall()]
bulk_chain/core/service_args.py DELETED
@@ -1,72 +0,0 @@
1
- class CmdArgsService:
2
-
3
- @staticmethod
4
- def autocast(v):
5
- for t in [int, float, str]:
6
- try:
7
- return t(v)
8
- except:
9
- pass
10
-
11
- @staticmethod
12
- def iter_arguments(lst):
13
-
14
- def __release():
15
-
16
- # We use the True value by default to treat the related parameter as flag.
17
- if len(buf) == 0:
18
- buf.append(True)
19
-
20
- return key, buf if len(buf) > 1 else buf[0]
21
-
22
- key = None
23
- buf = []
24
- for a in lst:
25
- if a.startswith('--'):
26
- # release
27
- if key is not None:
28
- yield __release()
29
- # set new key and empty buf
30
- key = a[2:]
31
- buf = []
32
- else:
33
- # append argument into buffer.
34
- buf.append(a)
35
-
36
- # Sharing the remaining params.
37
- if key is not None:
38
- yield __release()
39
-
40
- @staticmethod
41
- def __find_suffix_ind(lst, idx_from, end_prefix):
42
- for i in range(idx_from, len(lst)):
43
- if lst[i].startswith(end_prefix):
44
- return i
45
- return len(lst)
46
-
47
- @staticmethod
48
- def extract_native_args(lst, end_prefix):
49
- return lst[:CmdArgsService.__find_suffix_ind(lst, idx_from=0, end_prefix=end_prefix)]
50
-
51
- @staticmethod
52
- def find_grouped_args(lst, starts_with, end_prefix):
53
- """Slices a list in two, cutting on index matching "sep"
54
- """
55
-
56
- # Checking the presence of starts_with.
57
- # We have to return empty content in the case of absence starts_with in the lst.
58
- if starts_with not in lst:
59
- return []
60
-
61
- # Assigning start index.
62
- idx_from = lst.index(starts_with) + 1
63
-
64
- # Assigning end index.
65
- idx_to = CmdArgsService.__find_suffix_ind(lst, idx_from=idx_from, end_prefix=end_prefix)
66
-
67
- return lst[idx_from:idx_to]
68
-
69
- @staticmethod
70
- def args_to_dict(args):
71
- return {k: CmdArgsService.autocast(v) if not isinstance(v, list) else v
72
- for k, v in CmdArgsService.iter_arguments(args)} if args is not None else {}
bulk_chain/core/service_llm.py DELETED
@@ -1,68 +0,0 @@
1
- from bulk_chain.api import iter_content
2
- from bulk_chain.core.llm_base import BaseLM
3
- from bulk_chain.core.utils_logger import StreamedLogger
4
-
5
-
6
- def pad_str(text, pad):
7
- return text.rjust(len(text) + pad, ' ')
8
-
9
-
10
- def nice_output(text, remove_new_line=False):
11
- short_text = text.replace("\n", "") if remove_new_line else text
12
- return short_text
13
-
14
-
15
- def chat_with_lm(lm, preset_dict=None, schema=None, model_name=None, pad=0):
16
- assert (isinstance(lm, BaseLM))
17
- assert (isinstance(model_name, str) or model_name is None)
18
-
19
- preset_dict = {} if preset_dict is None else preset_dict
20
-
21
- streamed_logger = StreamedLogger(__name__)
22
- do_exit = False
23
- model_name = model_name if model_name is not None else "agent"
24
-
25
- while not do_exit:
26
-
27
- # Launching the CoT engine loop.
28
- data_dict = {} | preset_dict
29
-
30
- def callback_str_func(entry, info):
31
- streamed_logger.info(pad_str(f"{model_name} ({info['param']})->\n", pad=pad))
32
- streamed_logger.info(nice_output(entry, remove_new_line=False))
33
- streamed_logger.info("\n\n")
34
-
35
- def handle_missed_value(col_name):
36
- user_input = input(f"Enter your prompt for `{col_name}`"
37
- f"(or 'exit' to quit): ")
38
-
39
- if user_input.lower() == 'exit':
40
- exit(0)
41
-
42
- return user_input
43
-
44
- def callback_stream_func(entry, info):
45
- if entry is None and info["action"] == "start":
46
- streamed_logger.info(pad_str(f"{model_name} ({info['param']})->\n", pad=pad))
47
- return
48
- if entry is None and info["action"] == "end":
49
- streamed_logger.info("\n\n")
50
- return
51
-
52
- streamed_logger.info(entry)
53
-
54
- content_it = iter_content(
55
- input_dicts_it=[data_dict],
56
- llm=lm,
57
- schema=schema,
58
- batch_size=1,
59
- return_batch=True,
60
- handle_missed_value_func=handle_missed_value,
61
- callback_str_func=callback_str_func,
62
- callback_stream_func=callback_stream_func,
63
- )
64
-
65
- for _ in content_it:
66
- user_input = input(f"Enter to continue (or 'exit' to quit) ...\n")
67
- if user_input.lower() == 'exit':
68
- do_exit = True
bulk_chain/core/utils_logger.py DELETED
@@ -1,41 +0,0 @@
1
- import logging
2
-
3
-
4
- def StreamedLogger(name: str) -> logging.Logger:
5
- """ https://medium.com/@r.das699/optimizing-logging-practices-for-streaming-data-in-python-521798e1ed82
6
- """
7
- root_handlers = logging.getLogger().handlers
8
- current_logger = logging.getLogger(name)
9
- if not root_handlers:
10
- new_handler = logging.StreamHandler()
11
- new_handler.terminator = ""
12
- new_handler.setFormatter(logging.Formatter("%(message)s"))
13
- current_logger.addHandler(new_handler)
14
- current_logger.propagate = False
15
- current_logger.setLevel(logging.INFO)
16
- return current_logger
17
-
18
- for handler in current_logger.handlers[:]:
19
- current_logger.removeHandler(handler)
20
-
21
- for handler_r in root_handlers:
22
- if type(handler_r) is logging.StreamHandler:
23
- new_handler = logging.StreamHandler()
24
- new_handler.terminator = ""
25
- new_handler.setFormatter(logging.Formatter("%(message)s"))
26
- current_logger.addHandler(new_handler)
27
- elif type(handler_r) is logging.FileHandler:
28
- new_handler = logging.FileHandler(
29
- handler_r.baseFilename,
30
- handler_r.mode,
31
- handler_r.encoding,
32
- handler_r.delay,
33
- handler_r.errors,
34
- )
35
- new_handler.terminator = "" # This will stop the printing in new line
36
- new_handler.setFormatter(logging.Formatter("%(message)s"))
37
- current_logger.addHandler(new_handler)
38
- else:
39
- continue
40
- current_logger.propagate = False # Don't propagate to root logger
41
- return current_logger
bulk_chain/demo.py DELETED
@@ -1,84 +0,0 @@
1
- import json
2
-
3
- import argparse
4
- import logging
5
- import sys
6
-
7
- from source_iter.service_jsonl import JsonlService
8
-
9
- from bulk_chain.api import init_llm
10
- from bulk_chain.core.service_args import CmdArgsService
11
- from bulk_chain.core.service_json import JsonService
12
- from bulk_chain.core.service_llm import chat_with_lm
13
- from bulk_chain.core.service_schema import SchemaService
14
- from bulk_chain.core.utils import parse_filepath
15
-
16
- logger = logging.getLogger(__name__)
17
- logging.basicConfig(level=logging.INFO)
18
-
19
-
20
- def iter_from_json(filepath):
21
- with open(filepath, "r") as f:
22
- content = json.load(f)
23
- for key, value in content.items():
24
- yield key, value
25
-
26
-
27
- def iter_from_text_file(filepath):
28
- with open(filepath, "r") as f:
29
- yield filepath.split('.')[0], f.read()
30
-
31
-
32
- if __name__ == '__main__':
33
-
34
- parser = argparse.ArgumentParser(description="LLM demo usage based on CoT schema")
35
- parser.add_argument('--adapter', dest='adapter', type=str, default=None)
36
- parser.add_argument('--attempts', dest='attempts', type=int, default=None)
37
- parser.add_argument('--src', dest='src', type=str, nargs="*", default=None)
38
- parser.add_argument('--schema', dest='schema', type=str, default=None,
39
- help="Path to the JSON file that describes schema")
40
- parser.add_argument('--limit-prompt', dest="limit_prompt", type=int, default=None,
41
- help="Optional trimming prompt by the specified amount of characters.")
42
-
43
- # Extract native arguments.
44
- native_args = CmdArgsService.extract_native_args(sys.argv, end_prefix="%%")
45
- args = parser.parse_args(args=native_args[1:])
46
-
47
- # Extract model-related arguments and Initialize Large Language Model.
48
- model_args = CmdArgsService.find_grouped_args(lst=sys.argv, starts_with="%%m", end_prefix="%%")
49
- model_args_dict = CmdArgsService.args_to_dict(model_args) | {"attempts": args.attempts}
50
- llm, llm_model_name = init_llm(adapter=args.adapter, **model_args_dict)
51
-
52
- # Setup schema.
53
- schema = SchemaService(json_data=JsonService.read(args.schema))
54
- schema_name = schema.src.get("name", None)
55
- if schema is not None:
56
- logger.info(f"Using schema: {schema_name}")
57
-
58
- output_providers = {
59
- "jsonl": lambda filepath, data_it, header:
60
- JsonlService.write(target=filepath,
61
- data_it=map(lambda item: {key: item[i] for i, key in enumerate(header)}, data_it))
62
- }
63
-
64
- input_file_handlers = {
65
- "json": lambda filepath: iter_from_json(filepath),
66
- "txt": lambda filepath: iter_from_text_file(filepath)
67
- }
68
-
69
- # Input extension type defines the provider.
70
- if args.src is None:
71
- args.src = []
72
- if isinstance(args.src, str):
73
- args.src = [args.src]
74
- sources = [parse_filepath(s) for s in args.src]
75
-
76
- preset_dict = {}
77
- for fp, ext, _ in sources:
78
- for key, value in input_file_handlers[ext](fp):
79
- if key in preset_dict:
80
- raise Exception(f"While at handling {fp}: Key {key} is already registered!")
81
- preset_dict[key] = value
82
-
83
- # Launch Demo.
84
- chat_with_lm(llm, preset_dict=preset_dict, schema=schema, model_name=llm_model_name)
bulk_chain/infer.py DELETED
@@ -1,193 +0,0 @@
1
- from itertools import chain
2
- from os.path import join, basename
3
-
4
- import argparse
5
- import logging
6
- import sys
7
-
8
- from source_iter.service_csv import CsvService
9
- from source_iter.service_jsonl import JsonlService
10
- from tqdm import tqdm
11
-
12
- from bulk_chain.api import INFER_MODES, _infer_batch, CWD, init_llm
13
- from bulk_chain.core.llm_base import BaseLM
14
- from bulk_chain.core.provider_sqlite import SQLite3Service
15
- from bulk_chain.core.service_args import CmdArgsService
16
- from bulk_chain.core.service_batch import BatchIterator
17
- from bulk_chain.core.service_dict import DictionaryService
18
- from bulk_chain.core.service_json import JsonService
19
- from bulk_chain.core.service_schema import SchemaService
20
- from bulk_chain.core.utils import handle_table_name, optional_limit_iter, parse_filepath
21
-
22
- logger = logging.getLogger(__name__)
23
- logging.basicConfig(level=logging.INFO)
24
-
25
- WRITER_PROVIDERS = {
26
- "sqlite": lambda filepath, table_name, data_it, **kwargs: SQLite3Service.write(
27
- data_it=data_it, target=filepath, table_name=table_name, skip_existed=True, **kwargs)
28
- }
29
-
30
- READER_PROVIDERS = {
31
- "sqlite": lambda filepath, table_name: SQLite3Service.read(filepath, table=table_name)
32
- }
33
-
34
-
35
- def infer_batch(batch, columns=None, **kwargs):
36
- assert (len(batch) > 0)
37
- # TODO. Support proper selection of columns.
38
- cols = batch[0].keys() if columns is None else columns
39
- return _infer_batch(batch=batch, cols=cols, **kwargs)
40
-
41
-
42
- def raise_(ex):
43
- raise ex
44
-
45
-
46
- def iter_content_cached(input_dicts_it, llm, schema, cache_target, batch_size, id_column_name, limit_prompt=None,
47
- **cache_kwargs):
48
- assert (isinstance(llm, BaseLM))
49
- assert (isinstance(cache_target, str))
50
-
51
- # Quick initialization of the schema.
52
- if isinstance(schema, str):
53
- schema = JsonService.read(schema)
54
- if isinstance(schema, dict):
55
- schema = SchemaService(json_data=schema)
56
-
57
- # Parse target.
58
- cache_filepath, _, cache_table = parse_filepath(filepath=cache_target)
59
-
60
- # Iterator of the queries.
61
- prompts_it = map(
62
- lambda data: DictionaryService.custom_update(src_dict=data, other_dict=schema.cot_args),
63
- input_dicts_it
64
- )
65
-
66
- prompts_batched_it = BatchIterator(
67
- data_iter=iter(tqdm(prompts_it, desc="Iter Content")),
68
- batch_size=batch_size,
69
- filter_func=lambda data: not SQLite3Service.entry_exist(
70
- id_column_name=id_column_name, table_name=cache_table, target=cache_filepath,
71
- id_value=data[id_column_name], **cache_kwargs)
72
- )
73
-
74
- results_it = map(
75
- lambda batch: infer_batch(
76
- batch=batch, schema=schema,
77
- handle_batch_func=lambda batch: INFER_MODES["batch"](llm, batch, limit_prompt),
78
- handle_missed_value_func=lambda col_name: raise_(
79
- Exception(f"Value for {col_name} is undefined. Filling undefined values is not supported")
80
- )
81
- ),
82
- prompts_batched_it
83
- )
84
-
85
- # Perform caching first.
86
- WRITER_PROVIDERS["sqlite"](
87
- filepath=cache_filepath,
88
- table_name=cache_table,
89
- data_it=chain.from_iterable(results_it),
90
- id_column_name=id_column_name,
91
- **cache_kwargs)
92
-
93
- # Then retrieve data.
94
- return READER_PROVIDERS["sqlite"](filepath=cache_filepath, table_name=cache_table)
95
-
96
-
97
- if __name__ == '__main__':
98
-
99
- parser = argparse.ArgumentParser(description="Infer Instruct LLM inference based on CoT schema")
100
- parser.add_argument('--adapter', dest='adapter', type=str, default=None)
101
- parser.add_argument('--id-col', dest='id_col', type=str, default="uid")
102
- parser.add_argument('--src', dest='src', type=str, nargs="?", default=None)
103
- parser.add_argument('--schema', dest='schema', type=str, default=None,
104
- help="Path to the JSON file that describes schema")
105
- parser.add_argument('--to', dest='to', type=str, default=None, choices=["csv", "sqlite"])
106
- parser.add_argument('--output', dest='output', type=str, default=None)
107
- parser.add_argument('--limit', dest='limit', type=int, default=None,
108
- help="Limit amount of source texts for prompting.")
109
- parser.add_argument('--batch-size', dest='batch_size', type=int, default=1)
110
- parser.add_argument('--limit-prompt', dest="limit_prompt", type=int, default=None,
111
- help="Optional trimming prompt by the specified amount of characters.")
112
-
113
- # Extract native arguments.
114
- native_args = CmdArgsService.extract_native_args(sys.argv, end_prefix="%%")
115
- args = parser.parse_args(args=native_args[1:])
116
-
117
- # Extract csv-related arguments.
118
- csv_args = CmdArgsService.find_grouped_args(lst=sys.argv, starts_with="%%csv", end_prefix="%%")
119
- csv_args_dict = CmdArgsService.args_to_dict(csv_args)
120
-
121
- # Extract model-related arguments and Initialize Large Language Model.
122
- model_args = CmdArgsService.find_grouped_args(lst=sys.argv, starts_with="%%m", end_prefix="%%")
123
- model_args_dict = CmdArgsService.args_to_dict(model_args) | {"attempts": 1}
124
- llm, llm_model_name = init_llm(adapter=args.adapter, **model_args_dict)
125
-
126
- # Setup schema.
127
- schema = SchemaService(json_data=JsonService.read(args.schema))
128
- schema_name = schema.src.get("name", None)
129
- if schema is not None:
130
- logger.info(f"Using schema: {schema_name}")
131
-
132
- input_providers = {
133
- "csv": lambda filepath: CsvService.read(src=filepath, row_id_key=args.id_col,
134
- as_dict=True, skip_header=True,
135
- delimiter=csv_args_dict.get("delimiter", ","),
136
- escapechar=csv_args_dict.get("escapechar", None)),
137
- "tsv": lambda filepath: CsvService.read(src=filepath, row_id_key=args.id_col,
138
- as_dict=True, skip_header=True,
139
- delimiter=csv_args_dict.get("delimiter", "\t"),
140
- escapechar=csv_args_dict.get("escapechar", None)),
141
- "jsonl": lambda filepath: JsonlService.read(src=filepath, row_id_key=args.id_col)
142
- }
143
-
144
- output_providers = {
145
- "csv": lambda filepath, data_it, header: CsvService.write(target=filepath,
146
- data_it=data_it, header=header,
147
- delimiter=csv_args_dict.get("delimiter", ","),
148
- escapechar=csv_args_dict.get("escapechar", None),
149
- it_type=None),
150
- "tsv": lambda filepath, data_it, header: CsvService.write(target=filepath,
151
- data_it=data_it, header=header,
152
- delimiter=csv_args_dict.get("delimiter", "\t"),
153
- escapechar=csv_args_dict.get("escapechar", None),
154
- it_type=None),
155
- "jsonl": lambda filepath, data_it, header:
156
- JsonlService.write(target=filepath,
157
- data_it=map(lambda item: {key: item[i] for i, key in enumerate(header)}, data_it))
158
- }
159
-
160
- # Setup output.
161
- args.output = args.output.format(model=llm.name()) if args.output is not None else args.output
162
- tgt_filepath, tgt_ext, tgt_meta = parse_filepath(args.output, default_ext=args.to)
163
-
164
- # We do not support multiple files for other modes.
165
- src_filepath, src_ext, src_meta = parse_filepath(args.src)
166
-
167
- def default_output_file_template(ext):
168
- # This is a default template for output files to be generated.
169
- return "".join(["_".join([join(CWD, basename(src_filepath)), llm.name(), schema_name]), ext])
170
-
171
- # Setup cache target as well as the related table.
172
- cache_filepath = default_output_file_template(".sqlite") if tgt_filepath is None else tgt_filepath
173
- cache_table = handle_table_name(tgt_meta if tgt_meta is not None else "contents")
174
-
175
- # This is a content that we extracted via input provider.
176
- it_data = input_providers[src_ext](src_filepath)
177
-
178
- data_it = iter_content_cached(input_dicts_it=optional_limit_iter(it_data=it_data, limit=args.limit),
179
- limit_prompt=args.limit_prompt,
180
- schema=schema,
181
- llm=llm,
182
- batch_size=args.batch_size,
183
- id_column_name=args.id_col,
184
- cache_target=":".join([cache_filepath, cache_table]))
185
-
186
- # Setup output target
187
- tgt_ext = src_ext if tgt_ext is None else tgt_ext
188
- output_target = default_output_file_template(f".{tgt_ext}") if tgt_filepath is None else tgt_filepath
189
-
190
- # Perform output writing process.
191
- output_providers[tgt_ext](filepath=output_target,
192
- data_it=data_it,
193
- header=SQLite3Service.read_columns(target=cache_filepath, table=cache_table))
bulk_chain-0.25.3.dist-info/RECORD DELETED
@@ -1,21 +0,0 @@
1
- bulk_chain/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
- bulk_chain/api.py,sha256=8hXJb66bEOf1izgeBmjrB9LexSMJD7GquUhIm76lfmY,4351
3
- bulk_chain/demo.py,sha256=20r_-ioR3fu3eqHJnCRK4aQmBKTMgjFAHzZPJcXaEz8,3186
4
- bulk_chain/infer.py,sha256=Qb7ZV3_DXIh1jQoEjER_H2rO4ntREj_iIwGpEpAXJCE,9157
5
- bulk_chain/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
6
- bulk_chain/core/llm_base.py,sha256=fuWxfEOSRYvnoZMOcfnq1E2LIJKnrpsnxQ1z6SmY1nM,1839
7
- bulk_chain/core/provider_sqlite.py,sha256=sW4Yefp_zYuL8xVys8la5hG0Ng94jSqiXelPgGGB5B0,5327
8
- bulk_chain/core/service_args.py,sha256=lq4Veuh4QNu8mlCv8MT9S1rMxTn4FKalyp-3boYonVk,2136
9
- bulk_chain/core/service_batch.py,sha256=z1OND6x40QBvK2H_Wt8sRS8MadUvYTCHmbS_dCm9t7M,1678
10
- bulk_chain/core/service_data.py,sha256=OWWHHnr_plwxYTxLuvMrhEc1PbSx-XC3rbFzV0hy3vk,1107
11
- bulk_chain/core/service_dict.py,sha256=lAghLU-3V3xYGv5BTA327Qcw8UJYmgQRMFdggzlrUgo,383
12
- bulk_chain/core/service_json.py,sha256=6o1xM_8c9QEjH9Q3qEmJylU9nahfRXhUd5sFF2dGJwo,182
13
- bulk_chain/core/service_llm.py,sha256=3WYoBgaiqoRwsoKq6VUUNMasbLj5rMCjGkU3OQVxGf8,2278
14
- bulk_chain/core/service_schema.py,sha256=KIP4n0Tz2h1i7SIMGhgAhoiCgUFXOT1rzMt38yACS2U,1154
15
- bulk_chain/core/utils.py,sha256=1irk3RGLzJxeLZ-Tv0oOceOKG0ADtpEWniG2UGbYA_U,3089
16
- bulk_chain/core/utils_logger.py,sha256=BD-ADxaeeuHztaYjqtIY_cIzc5r2Svq9XwRtrgIEqyI,1636
17
- bulk_chain-0.25.3.dist-info/LICENSE,sha256=VF9SjNpwwSSFEY_eP_8A1ocDCrbwfjI1pZexXdCkOwo,1076
18
- bulk_chain-0.25.3.dist-info/METADATA,sha256=-yt6rNoJSGqHRe-6z-mdf51wtwpVLYvcr8Yv_1qkdxQ,6037
19
- bulk_chain-0.25.3.dist-info/WHEEL,sha256=pL8R0wFFS65tNSRnaOVrsw9EOkOqxLrlUPenUYnJKNo,91
20
- bulk_chain-0.25.3.dist-info/top_level.txt,sha256=Hxq_wyH-GDXKBaA63UfBIiMJO2eCHJG5EOrXDphpeB4,11
21
- bulk_chain-0.25.3.dist-info/RECORD,,