cat-stack 1.6.8__tar.gz → 2.0.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {cat_stack-1.6.8 → cat_stack-2.0.0}/.gitignore +8 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0}/PKG-INFO +62 -1
- {cat_stack-1.6.8 → cat_stack-2.0.0}/README.md +61 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0}/src/catstack/__about__.py +1 -1
- {cat_stack-1.6.8 → cat_stack-2.0.0}/src/catstack/__init__.py +2 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0}/src/catstack/_formatter.py +90 -8
- {cat_stack-1.6.8 → cat_stack-2.0.0}/src/catstack/_providers.py +470 -168
- {cat_stack-1.6.8 → cat_stack-2.0.0}/src/catstack/calls/CoVe.py +68 -53
- {cat_stack-1.6.8 → cat_stack-2.0.0}/src/catstack/calls/image_CoVe.py +60 -43
- {cat_stack-1.6.8 → cat_stack-2.0.0}/src/catstack/calls/image_stepback.py +16 -7
- {cat_stack-1.6.8 → cat_stack-2.0.0}/src/catstack/calls/pdf_CoVe.py +50 -41
- {cat_stack-1.6.8 → cat_stack-2.0.0}/src/catstack/calls/pdf_stepback.py +16 -7
- {cat_stack-1.6.8 → cat_stack-2.0.0}/src/catstack/calls/stepback.py +21 -12
- {cat_stack-1.6.8 → cat_stack-2.0.0}/src/catstack/calls/top_n.py +18 -10
- {cat_stack-1.6.8 → cat_stack-2.0.0}/src/catstack/classify.py +121 -4
- cat_stack-2.0.0/src/catstack/collapse_themes.py +479 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0}/src/catstack/image_functions.py +92 -49
- {cat_stack-1.6.8 → cat_stack-2.0.0}/src/catstack/pdf_functions.py +88 -46
- {cat_stack-1.6.8 → cat_stack-2.0.0}/src/catstack/text_functions_ensemble.py +14 -2
- {cat_stack-1.6.8 → cat_stack-2.0.0}/LICENSE +0 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0}/pyproject.toml +0 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0}/src/cat_stack/__init__.py +0 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0}/src/catstack/_batch.py +0 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0}/src/catstack/_category_analysis.py +0 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0}/src/catstack/_chunked.py +0 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0}/src/catstack/_embeddings.py +0 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0}/src/catstack/_pilot_test.py +0 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0}/src/catstack/_prompts.py +0 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0}/src/catstack/_review_ui.py +0 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0}/src/catstack/_tiebreaker.py +0 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0}/src/catstack/_utils.py +0 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0}/src/catstack/_web_fetch.py +0 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0}/src/catstack/_wrapper_helpers.py +0 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0}/src/catstack/calls/__init__.py +0 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0}/src/catstack/explore.py +0 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0}/src/catstack/extract.py +0 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0}/src/catstack/images/circle.png +0 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0}/src/catstack/images/cube.png +0 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0}/src/catstack/images/diamond.png +0 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0}/src/catstack/images/overlapping_pentagons.png +0 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0}/src/catstack/images/rectangles.png +0 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0}/src/catstack/model_reference_list.py +0 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0}/src/catstack/prompt_tune.py +0 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0}/src/catstack/summarize.py +0 -0
- {cat_stack-1.6.8 → cat_stack-2.0.0}/src/catstack/text_functions.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: cat-stack
|
|
3
|
-
Version: 1.6.8
|
|
3
|
+
Version: 2.0.0
|
|
4
4
|
Summary: Domain-agnostic text, image, PDF, and DOCX classification engine powered by LLMs
|
|
5
5
|
Project-URL: Documentation, https://github.com/chrissoria/cat-stack#readme
|
|
6
6
|
Project-URL: Issues, https://github.com/chrissoria/cat-stack/issues
|
|
@@ -154,6 +154,67 @@ cat.explore(
|
|
|
154
154
|
)
|
|
155
155
|
```
|
|
156
156
|
|
|
157
|
+
### `collapse_themes()`
|
|
158
|
+
Consolidate a long, redundant list of extracted category labels (e.g. the output of `explore()`) into a smaller, deduplicated taxonomy. Runs the semantic merge iteratively, then applies a single deterministic embedding re-merge over the whole result to collapse cross-batch lexical siblings (e.g. "tension" / "estrangement") that batched passes leave separate. Tuned to err toward over-segmentation (keeping categories) rather than over-merging.
|
|
159
|
+
|
|
160
|
+
```python
|
|
161
|
+
# Basic: aggressive merge, auto-stop at the quality peak
|
|
162
|
+
cat.collapse_themes(
|
|
163
|
+
input_data=raw_labels, # list[str] or a frequency Series/dict
|
|
164
|
+
api_key=key,
|
|
165
|
+
description="Why did you move?", # the survey question / context
|
|
166
|
+
aggressive=True,
|
|
167
|
+
passes="auto",
|
|
168
|
+
user_model="gpt-4o",
|
|
169
|
+
)
|
|
170
|
+
```
|
|
171
|
+
|
|
172
|
+
```python
|
|
173
|
+
# Per-step model assignment: a cheap model thins restatements,
|
|
174
|
+
# a stronger model does the conceptual merge (providers can differ)
|
|
175
|
+
cat.collapse_themes(
|
|
176
|
+
input_data=raw_labels,
|
|
177
|
+
api_key=key,
|
|
178
|
+
description="Why did you move?",
|
|
179
|
+
aggressive=True,
|
|
180
|
+
passes="auto",
|
|
181
|
+
unique_model="Qwen/Qwen2.5-72B-Instruct:together",
|
|
182
|
+
unique_model_source="huggingface",
|
|
183
|
+
unique_passes=1,
|
|
184
|
+
merge_model="Qwen/Qwen3.6-35B-A3B:together",
|
|
185
|
+
merge_model_source="huggingface",
|
|
186
|
+
max_workers=8,
|
|
187
|
+
)
|
|
188
|
+
```
|
|
189
|
+
|
|
190
|
+
**Parameters**
|
|
191
|
+
|
|
192
|
+
| Parameter | Default | Description |
|
|
193
|
+
| --- | --- | --- |
|
|
194
|
+
| `input_data` | — | List of category labels, or a frequency `Series`/`dict` (`label -> count`). |
|
|
195
|
+
| `api_key` | `None` | API key for the LLM provider (required). |
|
|
196
|
+
| `description` | `""` | The survey question or context, used in the merge prompt. |
|
|
197
|
+
| `passes` | `1` | Number of merge iterations, or `"auto"` to iterate until the embedding-quality benchmark peaks. |
|
|
198
|
+
| `max_passes` | `10` | Cap on iterations when `passes="auto"`. |
|
|
199
|
+
| `batch_size` | `40` | Labels per LLM chunk (`ceil(n / batch_size)` calls per pass). |
|
|
200
|
+
| `aggressive` | `False` | `True` = conceptual-merge prompt (compress related labels); `False` = extract-unique (faithful thinning, removes restatements only). |
|
|
201
|
+
| `dedupe_threshold` | `0.95` | Jaro-Winkler similarity at/above which normalized labels are deduped (`1.0` = exact only). |
|
|
202
|
+
| `embedding_merge_threshold` | `0.92` | Cosine similarity at/above which labels are merged in the pre-LLM embedding step. `None`/`>=1.0` disables it. |
|
|
203
|
+
| `shuffle` | `True` | Randomize order each pass so batch composition varies (improves convergence stability). |
|
|
204
|
+
| `final_consolidation` | `0.82` | Cosine threshold for one greedy global embedding re-merge after all passes, collapsing cross-batch duplicates. Conservative by design (errs toward keeping categories). `False`/`None` skips it. |
|
|
205
|
+
| `user_model` | `"gpt-4o"` | Model for the merge phase. Use a capable model — small models can degenerate. |
|
|
206
|
+
| `model_source` | `"auto"` | Provider for `user_model` (`"auto"`, `"openai"`, `"huggingface"`, …). |
|
|
207
|
+
| `unique_model` | `None` | If set, run an initial extract-unique thinning phase on this (typically cheaper) model before the merge phase. `None` skips the phase (backward compatible). |
|
|
208
|
+
| `unique_model_source` | `"auto"` | Provider for `unique_model` — can differ from the merge phase. |
|
|
209
|
+
| `unique_passes` | `1` | Number of thinning passes when `unique_model` is set. |
|
|
210
|
+
| `merge_model` | `None` | Model for the merge phase; falls back to `user_model` when `None`. |
|
|
211
|
+
| `merge_model_source` | `"auto"` | Provider for `merge_model`. |
|
|
212
|
+
| `creativity` | `0` | Temperature (`0` = deterministic). |
|
|
213
|
+
| `max_workers` | `1` | Batches processed concurrently per pass. |
|
|
214
|
+
| `random_state` | `None` | Seed for shuffling (per-pass seed = `random_state + pass`). |
|
|
215
|
+
| `filename` | `None` | Optional CSV path to save the final list. |
|
|
216
|
+
| `progress_callback` | `None` | Optional `callback(pass, passes, label)` for progress reporting. |
|
|
217
|
+
|
|
157
218
|
### `summarize()`
|
|
158
219
|
Summarize text or PDF documents, with optional multi-model ensemble.
|
|
159
220
|
|
|
@@ -118,6 +118,67 @@ cat.explore(
|
|
|
118
118
|
)
|
|
119
119
|
```
|
|
120
120
|
|
|
121
|
+
### `collapse_themes()`
|
|
122
|
+
Consolidate a long, redundant list of extracted category labels (e.g. the output of `explore()`) into a smaller, deduplicated taxonomy. Runs the semantic merge iteratively, then applies a single deterministic embedding re-merge over the whole result to collapse cross-batch lexical siblings (e.g. "tension" / "estrangement") that batched passes leave separate. Tuned to err toward over-segmentation (keeping categories) rather than over-merging.
|
|
123
|
+
|
|
124
|
+
```python
|
|
125
|
+
# Basic: aggressive merge, auto-stop at the quality peak
|
|
126
|
+
cat.collapse_themes(
|
|
127
|
+
input_data=raw_labels, # list[str] or a frequency Series/dict
|
|
128
|
+
api_key=key,
|
|
129
|
+
description="Why did you move?", # the survey question / context
|
|
130
|
+
aggressive=True,
|
|
131
|
+
passes="auto",
|
|
132
|
+
user_model="gpt-4o",
|
|
133
|
+
)
|
|
134
|
+
```
|
|
135
|
+
|
|
136
|
+
```python
|
|
137
|
+
# Per-step model assignment: a cheap model thins restatements,
|
|
138
|
+
# a stronger model does the conceptual merge (providers can differ)
|
|
139
|
+
cat.collapse_themes(
|
|
140
|
+
input_data=raw_labels,
|
|
141
|
+
api_key=key,
|
|
142
|
+
description="Why did you move?",
|
|
143
|
+
aggressive=True,
|
|
144
|
+
passes="auto",
|
|
145
|
+
unique_model="Qwen/Qwen2.5-72B-Instruct:together",
|
|
146
|
+
unique_model_source="huggingface",
|
|
147
|
+
unique_passes=1,
|
|
148
|
+
merge_model="Qwen/Qwen3.6-35B-A3B:together",
|
|
149
|
+
merge_model_source="huggingface",
|
|
150
|
+
max_workers=8,
|
|
151
|
+
)
|
|
152
|
+
```
|
|
153
|
+
|
|
154
|
+
**Parameters**
|
|
155
|
+
|
|
156
|
+
| Parameter | Default | Description |
|
|
157
|
+
| --- | --- | --- |
|
|
158
|
+
| `input_data` | — | List of category labels, or a frequency `Series`/`dict` (`label -> count`). |
|
|
159
|
+
| `api_key` | `None` | API key for the LLM provider (required). |
|
|
160
|
+
| `description` | `""` | The survey question or context, used in the merge prompt. |
|
|
161
|
+
| `passes` | `1` | Number of merge iterations, or `"auto"` to iterate until the embedding-quality benchmark peaks. |
|
|
162
|
+
| `max_passes` | `10` | Cap on iterations when `passes="auto"`. |
|
|
163
|
+
| `batch_size` | `40` | Labels per LLM chunk (`ceil(n / batch_size)` calls per pass). |
|
|
164
|
+
| `aggressive` | `False` | `True` = conceptual-merge prompt (compress related labels); `False` = extract-unique (faithful thinning, removes restatements only). |
|
|
165
|
+
| `dedupe_threshold` | `0.95` | Jaro-Winkler similarity at/above which normalized labels are deduped (`1.0` = exact only). |
|
|
166
|
+
| `embedding_merge_threshold` | `0.92` | Cosine similarity at/above which labels are merged in the pre-LLM embedding step. `None`/`>=1.0` disables it. |
|
|
167
|
+
| `shuffle` | `True` | Randomize order each pass so batch composition varies (improves convergence stability). |
|
|
168
|
+
| `final_consolidation` | `0.82` | Cosine threshold for one greedy global embedding re-merge after all passes, collapsing cross-batch duplicates. Conservative by design (errs toward keeping categories). `False`/`None` skips it. |
|
|
169
|
+
| `user_model` | `"gpt-4o"` | Model for the merge phase. Use a capable model — small models can degenerate. |
|
|
170
|
+
| `model_source` | `"auto"` | Provider for `user_model` (`"auto"`, `"openai"`, `"huggingface"`, …). |
|
|
171
|
+
| `unique_model` | `None` | If set, run an initial extract-unique thinning phase on this (typically cheaper) model before the merge phase. `None` skips the phase (backward compatible). |
|
|
172
|
+
| `unique_model_source` | `"auto"` | Provider for `unique_model` — can differ from the merge phase. |
|
|
173
|
+
| `unique_passes` | `1` | Number of thinning passes when `unique_model` is set. |
|
|
174
|
+
| `merge_model` | `None` | Model for the merge phase; falls back to `user_model` when `None`. |
|
|
175
|
+
| `merge_model_source` | `"auto"` | Provider for `merge_model`. |
|
|
176
|
+
| `creativity` | `0` | Temperature (`0` = deterministic). |
|
|
177
|
+
| `max_workers` | `1` | Batches processed concurrently per pass. |
|
|
178
|
+
| `random_state` | `None` | Seed for shuffling (per-pass seed = `random_state + pass`). |
|
|
179
|
+
| `filename` | `None` | Optional CSV path to save the final list. |
|
|
180
|
+
| `progress_callback` | `None` | Optional `callback(pass, passes, label)` for progress reporting. |
|
|
181
|
+
|
|
121
182
|
### `summarize()`
|
|
122
183
|
Summarize text or PDF documents, with optional multi-model ensemble.
|
|
123
184
|
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
# SPDX-FileCopyrightText: 2025-present Christopher Soria <chrissoria@berkeley.edu>
|
|
2
2
|
#
|
|
3
3
|
# SPDX-License-Identifier: GPL-3.0-or-later
|
|
4
|
-
__version__ = "1.6.8"
|
|
4
|
+
__version__ = "2.0.0"
|
|
5
5
|
__author__ = "Chris Soria"
|
|
6
6
|
__email__ = "chrissoria@berkeley.edu"
|
|
7
7
|
__title__ = "cat-stack"
|
|
@@ -18,6 +18,7 @@ from .__about__ import (
|
|
|
18
18
|
# Main entry points
|
|
19
19
|
from .extract import extract
|
|
20
20
|
from .explore import explore
|
|
21
|
+
from .collapse_themes import collapse_themes
|
|
21
22
|
from .classify import classify
|
|
22
23
|
from .summarize import summarize
|
|
23
24
|
from .prompt_tune import prompt_tune
|
|
@@ -103,6 +104,7 @@ __all__ = [
|
|
|
103
104
|
# Main entry points
|
|
104
105
|
"extract",
|
|
105
106
|
"explore",
|
|
107
|
+
"collapse_themes",
|
|
106
108
|
"classify",
|
|
107
109
|
"summarize",
|
|
108
110
|
"prompt_tune",
|
|
@@ -45,6 +45,11 @@ def _check_dependencies():
|
|
|
45
45
|
def _check_dependencies_installed() -> bool:
|
|
46
46
|
"""Pure check — returns True if all formatter deps import successfully.
|
|
47
47
|
No side effects, no install attempts."""
|
|
48
|
+
# If a dep was just pip-installed in this process's lifetime, the import
|
|
49
|
+
# system may have cached its earlier absence; clear that so the re-check
|
|
50
|
+
# actually sees the freshly-installed package.
|
|
51
|
+
import importlib
|
|
52
|
+
importlib.invalidate_caches()
|
|
48
53
|
try:
|
|
49
54
|
import torch # noqa: F401
|
|
50
55
|
import transformers # noqa: F401
|
|
@@ -165,7 +170,31 @@ def _ensure_dependencies(verbose: bool = True) -> bool:
|
|
|
165
170
|
" To skip this and disable the formatter, pass json_formatter=False."
|
|
166
171
|
)
|
|
167
172
|
|
|
168
|
-
|
|
173
|
+
ok = _install_dependencies(verbose=verbose)
|
|
174
|
+
if not ok:
|
|
175
|
+
# Freshly pip-installed packages (esp. compiled ones like torch) often
|
|
176
|
+
# cannot be imported by the SAME running process — but they ARE on disk
|
|
177
|
+
# now. Tell the user to re-run rather than silently degrading every row
|
|
178
|
+
# to an error.
|
|
179
|
+
if verbose and _deps_on_disk():
|
|
180
|
+
print(
|
|
181
|
+
"[CatLLM] Formatter dependencies were just installed but cannot "
|
|
182
|
+
"be imported into the already-running process. Please RE-RUN your "
|
|
183
|
+
"command — they will load on the next start. (Avoid this by "
|
|
184
|
+
"pre-installing: pip install 'cat-stack[formatter]'.)"
|
|
185
|
+
)
|
|
186
|
+
return ok
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
def _deps_on_disk() -> bool:
|
|
190
|
+
"""True if the formatter deps are findable on disk (importable in a FRESH
|
|
191
|
+
process) even if they failed to import in the current one."""
|
|
192
|
+
import importlib.util
|
|
193
|
+
try:
|
|
194
|
+
return all(importlib.util.find_spec(m) is not None
|
|
195
|
+
for m in ("torch", "transformers", "accelerate"))
|
|
196
|
+
except (ImportError, ValueError):
|
|
197
|
+
return False
|
|
169
198
|
|
|
170
199
|
|
|
171
200
|
def _is_model_cached() -> bool:
|
|
@@ -205,6 +234,51 @@ def ensure_formatter_available() -> bool:
|
|
|
205
234
|
return True # actual download happens in load_formatter()
|
|
206
235
|
|
|
207
236
|
|
|
237
|
+
def _load_formatter_tokenizer(AutoTokenizer):
|
|
238
|
+
"""Load the formatter tokenizer, defending against a malformed
|
|
239
|
+
`tokenizer_config.json`.
|
|
240
|
+
|
|
241
|
+
Some published configs store `extra_special_tokens` as a LIST, but
|
|
242
|
+
transformers 4.56–4.57.x expect a dict and crash in
|
|
243
|
+
`_set_model_specific_special_tokens` with
|
|
244
|
+
`'list' object has no attribute 'keys'`. On that failure we snapshot the
|
|
245
|
+
repo locally, normalize a list-valued `extra_special_tokens` to `{}`
|
|
246
|
+
(the tokens already live in `added_tokens`/`special_tokens_map`, so
|
|
247
|
+
dropping the field is lossless), and load from the patched local copy.
|
|
248
|
+
"""
|
|
249
|
+
try:
|
|
250
|
+
return AutoTokenizer.from_pretrained(
|
|
251
|
+
_MERGED_MODEL_REPO, trust_remote_code=True
|
|
252
|
+
)
|
|
253
|
+
except (AttributeError, TypeError) as e:
|
|
254
|
+
if "keys" not in str(e) and "extra_special_tokens" not in str(e):
|
|
255
|
+
raise
|
|
256
|
+
import json
|
|
257
|
+
import os
|
|
258
|
+
from huggingface_hub import snapshot_download
|
|
259
|
+
|
|
260
|
+
local_dir = snapshot_download(_MERGED_MODEL_REPO)
|
|
261
|
+
cfg_path = os.path.join(local_dir, "tokenizer_config.json")
|
|
262
|
+
with open(cfg_path) as f:
|
|
263
|
+
cfg = json.load(f)
|
|
264
|
+
if isinstance(cfg.get("extra_special_tokens"), list):
|
|
265
|
+
cfg["extra_special_tokens"] = {}
|
|
266
|
+
# snapshot dirs are often read-only symlink caches; patch a copy.
|
|
267
|
+
import tempfile
|
|
268
|
+
import shutil
|
|
269
|
+
patched = tempfile.mkdtemp(prefix="catllm_formatter_tok_")
|
|
270
|
+
for fn in os.listdir(local_dir):
|
|
271
|
+
src = os.path.join(local_dir, fn)
|
|
272
|
+
if os.path.isfile(src):
|
|
273
|
+
shutil.copy(src, os.path.join(patched, fn))
|
|
274
|
+
with open(os.path.join(patched, "tokenizer_config.json"), "w") as f:
|
|
275
|
+
json.dump(cfg, f)
|
|
276
|
+
print("[CatLLM] Patched malformed extra_special_tokens in the "
|
|
277
|
+
"formatter tokenizer config (list -> {}).")
|
|
278
|
+
return AutoTokenizer.from_pretrained(patched, trust_remote_code=True)
|
|
279
|
+
raise
|
|
280
|
+
|
|
281
|
+
|
|
208
282
|
def load_formatter(device=None):
|
|
209
283
|
"""
|
|
210
284
|
Load the merged formatter model and tokenizer.
|
|
@@ -230,15 +304,21 @@ def load_formatter(device=None):
|
|
|
230
304
|
dtype = torch.float16 if device == "cuda" else torch.float32
|
|
231
305
|
|
|
232
306
|
print(f"[CatLLM] Loading JSON formatter on {device}...")
|
|
233
|
-
tokenizer = AutoTokenizer.from_pretrained(
|
|
234
|
-
_MERGED_MODEL_REPO, trust_remote_code=True
|
|
235
|
-
)
|
|
307
|
+
tokenizer = _load_formatter_tokenizer(AutoTokenizer)
|
|
236
308
|
if tokenizer.pad_token is None:
|
|
237
309
|
tokenizer.pad_token = tokenizer.eos_token
|
|
238
310
|
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
311
|
+
# `dtype=` is the transformers >=4.56 kwarg; older versions only accept
|
|
312
|
+
# `torch_dtype=` and crash if `dtype=` leaks into the config. Try the new
|
|
313
|
+
# name, fall back to the old one.
|
|
314
|
+
try:
|
|
315
|
+
model = AutoModelForCausalLM.from_pretrained(
|
|
316
|
+
_MERGED_MODEL_REPO, dtype=dtype, trust_remote_code=True
|
|
317
|
+
)
|
|
318
|
+
except TypeError:
|
|
319
|
+
model = AutoModelForCausalLM.from_pretrained(
|
|
320
|
+
_MERGED_MODEL_REPO, torch_dtype=dtype, trust_remote_code=True
|
|
321
|
+
)
|
|
242
322
|
model = model.to(device)
|
|
243
323
|
model.eval()
|
|
244
324
|
|
|
@@ -281,7 +361,9 @@ def run_formatter(raw_output, categories, model, tokenizer, device):
|
|
|
281
361
|
with torch.no_grad():
|
|
282
362
|
out = model.generate(
|
|
283
363
|
**inputs,
|
|
284
|
-
|
|
364
|
+
# 512 (was 128): a large category set produces a long N-key JSON
|
|
365
|
+
# object; 128 tokens truncated it for 28/48-category tasks.
|
|
366
|
+
max_new_tokens=512,
|
|
285
367
|
do_sample=False,
|
|
286
368
|
temperature=None,
|
|
287
369
|
top_p=None,
|