xfmr_zem-0.2.0-py3-none-any.whl
This diff shows the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
- xfmr_zem/__init__.py +35 -0
- xfmr_zem/cli.py +295 -0
- xfmr_zem/client.py +208 -0
- xfmr_zem/orchestrators/parallel_local.py +92 -0
- xfmr_zem/schemas.py +15 -0
- xfmr_zem/server.py +188 -0
- xfmr_zem/servers/data_juicer/parameter.yaml +17 -0
- xfmr_zem/servers/data_juicer/server.py +113 -0
- xfmr_zem/servers/instruction_gen/parameter.yaml +12 -0
- xfmr_zem/servers/instruction_gen/server.py +90 -0
- xfmr_zem/servers/io/parameter.yaml +10 -0
- xfmr_zem/servers/io/server.py +95 -0
- xfmr_zem/servers/llm/server.py +47 -0
- xfmr_zem/servers/nemo_curator/parameter.yaml +17 -0
- xfmr_zem/servers/nemo_curator/server.py +118 -0
- xfmr_zem/servers/profiler/server.py +76 -0
- xfmr_zem/servers/sinks/server.py +48 -0
- xfmr_zem/zenml_wrapper.py +203 -0
- xfmr_zem-0.2.0.dist-info/METADATA +152 -0
- xfmr_zem-0.2.0.dist-info/RECORD +23 -0
- xfmr_zem-0.2.0.dist-info/WHEEL +4 -0
- xfmr_zem-0.2.0.dist-info/entry_points.txt +3 -0
- xfmr_zem-0.2.0.dist-info/licenses/LICENSE +201 -0
xfmr_zem/server.py
ADDED
@@ -0,0 +1,188 @@

from typing import Any, Callable, Dict, List, Optional, Union
import yaml
from pathlib import Path
from fastmcp import FastMCP
import inspect

class ZemServer(FastMCP):
    """
    Base class for Zem MCP Servers.
    Extends FastMCP to support parameter loading and standardized tool registration.
    """

    def __init__(
        self,
        name: str,
        Dependencies: Optional[List[str]] = None,
        parameter_file: Optional[str] = None,
        **kwargs
    ):
        super().__init__(name=name, **kwargs)
        self.parameter_file = parameter_file
        self.parameters = {}

        # 1. Load from file
        if parameter_file:
            self.load_parameters(parameter_file)

        # 2. Override with env params (from PipelineClient)
        import os
        env_params_str = os.environ.get("ZEM_PARAMETERS")
        if env_params_str:
            try:
                env_params = yaml.safe_load(env_params_str)
                if isinstance(env_params, dict):
                    self._merge_parameters(env_params)
            except Exception as e:
                print(f"Error loading ZEM_PARAMETERS: {e}")

    def load_parameters(self, file_path: str) -> Dict[str, Any]:
        """Load parameters from YAML file and merge them."""
        path = Path(file_path)
        if path.exists():
            with open(path, "r") as f:
                file_params = yaml.safe_load(f) or {}
            self._merge_parameters(file_params)
            return self.parameters
        return {}

    def _merge_parameters(self, new_params: Dict[str, Any]):
        """Deep merge and dot-notation expansion for parameters."""
        for key, value in new_params.items():
            if "." in key:
                # Expand "tool.param" to {"tool": {"param": value}}
                parts = key.split(".")
                d = self.parameters
                for part in parts[:-1]:
                    if part not in d or not isinstance(d[part], dict):
                        d[part] = {}
                    d = d[part]

                last_part = parts[-1]
                if isinstance(value, dict) and last_part in d and isinstance(d[last_part], dict):
                    self._deep_update(d[last_part], value)
                else:
                    d[last_part] = value
            else:
                # Top level merge
                if isinstance(value, dict) and key in self.parameters and isinstance(self.parameters[key], dict):
                    self._deep_update(self.parameters[key], value)
                else:
                    self.parameters[key] = value

    def _deep_update(self, target: Dict[str, Any], source: Dict[str, Any]):
        """Helper for deep dictionary update."""
        for k, v in source.items():
            if isinstance(v, dict) and k in target and isinstance(target[k], dict):
                self._deep_update(target[k], v)
            else:
                target[k] = v

    # Removed custom tool decorator to fix multiple values for argument 'name' error
    # Inherit directly from FastMCP.tool

    def get_data(self, data: Any) -> List[Dict[str, Any]]:
        """
        Standardized way to get data, supporting both direct lists and file references.
        """
        import os
        from loguru import logger

        logger.debug(f"Server {self.name}.get_data input type: {type(data)}")

        if isinstance(data, list):
            return data

        if isinstance(data, dict):
            if "path" in data:
                path = data["path"]
                ext = os.path.splitext(path)[1].lower()

                logger.debug(f"Server {self.name}: Loading reference {path}")
                if ext == ".jsonl":
                    import json
                    with open(path, "r", encoding="utf-8") as f:
                        return [json.loads(line) for line in f if line.strip()]
                elif ext == ".csv":
                    import pandas as pd
                    return pd.read_csv(path).to_dict(orient="records")
                elif ext == ".parquet":
                    import pandas as pd
                    return pd.read_parquet(path).to_dict(orient="records")
            else:
                # Single record dictionary
                logger.debug(f"Server {self.name}: Wrapping single dict in list")
                return [data]

        if isinstance(data, str):
            # 1. URL/Cloud URI or Local Path
            is_uri = data.startswith(("http", "s3://", "gs://"))
            if is_uri or os.path.exists(data):
                ext = os.path.splitext(data)[1].lower()
                logger.debug(f"Server {self.name}: Loading data from {data}")
                try:
                    import pandas as pd
                    if ext == ".parquet":
                        return pd.read_parquet(data).to_dict(orient="records")
                    elif ext == ".csv":
                        return pd.read_csv(data).to_dict(orient="records")
                    elif ext == ".jsonl":
                        if is_uri:
                            return pd.read_json(data, lines=True).to_dict(orient="records")
                        import json
                        with open(data, "r", encoding="utf-8") as f:
                            return [json.loads(line) for line in f if line.strip()]
                except Exception as e:
                    logger.error(f"Error loading data from {data}: {e}")

            # 2. Treat as raw text or JSON string
            try:
                import json
                parsed = json.loads(data)
                if isinstance(parsed, list): return parsed
                if isinstance(parsed, dict): return [parsed]
            except:
                pass
            return [{"text": data}]

        logger.warning(f"Server {self.name}: Unrecognized data type {type(data)}")
        return [{"raw": str(data)}]

    def save_output(self, data: Any, format: str = "parquet") -> Dict[str, Any]:
        """
        Saves output to a temporary file and returns a reference.
        Prevents large data from being sent over JSON-RPC.
        """
        import uuid
        import os
        from loguru import logger

        base_dir = "/tmp/zem_artifacts"
        os.makedirs(base_dir, exist_ok=True)

        file_id = str(uuid.uuid4())[:8]
        path = os.path.join(base_dir, f"{self.name}_output_{file_id}.{format}")

        logger.info(f"Server {self.name}: Saving result to reference {path}")

        if format == "parquet":
            import pandas as pd
            logger.debug(f"Server {self.name}: Converting to DataFrame, data type: {type(data)}")
            try:
                df = pd.DataFrame(data)
                df.to_parquet(path, index=False)
            except Exception as e:
                logger.error(f"Server {self.name}: Failed to create DataFrame from {type(data)}: {e}")
                raise
        elif format == "jsonl":
            import json
            with open(path, "w", encoding="utf-8") as f:
                for item in data:
                    f.write(json.dumps(item, ensure_ascii=False) + "\n")

        return {"path": path, "type": format, "size": os.path.getsize(path)}

    def run(self, transport: str = "stdio"):
        """Run the server."""
        super().run(transport=transport, show_banner=False)
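Taken together, load_parameters/_merge_parameters and get_data/save_output define the two conventions the rest of the package relies on: YAML parameters with dot-notation overrides, and file references instead of inline payloads. Below is a minimal sketch of both behaviours, assuming fastmcp, pandas and a Parquet engine (pyarrow) are installed; the server name "demo" and the key "clean_content.remove_emojis" are illustrative, not part of the package.

import os
from xfmr_zem.server import ZemServer

# Dotted keys in ZEM_PARAMETERS are parsed as YAML and expanded into nested dicts.
# (Key name is illustrative.)
os.environ["ZEM_PARAMETERS"] = "clean_content.remove_emojis: true"
srv = ZemServer("demo")
print(srv.parameters)        # {'clean_content': {'remove_emojis': True}}

# save_output writes a Parquet reference under /tmp/zem_artifacts;
# get_data accepts that reference dict and loads the rows back.
ref = srv.save_output([{"text": "xin chào"}])
print(ref["path"], ref["size"])
print(srv.get_data(ref))     # [{'text': 'xin chào'}]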
xfmr_zem/servers/data_juicer/parameter.yaml
ADDED
@@ -0,0 +1,17 @@
# Technical Parameters for DataJuicer Server

clean_content:
  remove_html: true
  remove_emojis: false
  text_column: "text"

refining_filter:
  min_len: 10
  max_len: 100000
  alphanumeric_ratio: 0.1
  text_column: "text"

language_id:
  expected_lang: "vi"
  min_score: 0.8
  text_column: "text"
xfmr_zem/servers/data_juicer/server.py
ADDED
@@ -0,0 +1,113 @@
import os
import sys
import re
from typing import Any, Dict, List, Optional
from xfmr_zem.server import ZemServer
from loguru import logger

# Setup logging
logger.remove()
logger.add(sys.stderr, level="INFO")

server = ZemServer("data_juicer", parameter_file=os.path.join(os.path.dirname(__file__), "parameter.yaml"))

@server.tool()
def clean_content(
    data: Any,
    remove_html: bool = True,
    remove_emojis: bool = False,
    text_column: str = "text"
) -> Any:
    """
    Flexible content cleaning tool using DataJuicer logic.
    """
    items = server.get_data(data)
    if not items: return []

    logger.info(f"DataJuicer: Cleaning content (remove_html={remove_html}, remove_emojis={remove_emojis})")

    for item in items:
        if text_column not in item: continue
        text = str(item[text_column])

        if remove_html:
            text = re.sub(r'<[^>]+>', '', text)

        if remove_emojis:
            # Targeted emoji removal: remove high-plane Unicode characters (likely emojis)
            # while preserving Vietnamese diacritics and other standard Unicode text.
            text = re.sub(r'[\U00010000-\U0010ffff]', '', text)

        item[text_column] = text.strip()

    return server.save_output(items)

@server.tool()
def refining_filter(
    data: Any,
    min_len: int = 10,
    max_len: int = 100000,
    alphanumeric_ratio: float = 0.1,
    text_column: str = "text"
) -> Any:
    """
    Filter items based on technical refining metrics.
    """
    items = server.get_data(data)
    if not items: return []

    logger.info(f"DataJuicer: Refining filter (min_len={min_len}, alpha_ratio={alphanumeric_ratio})")

    filtered = []
    for item in items:
        text = str(item.get(text_column, ""))
        text_len = len(text)

        if not (min_len <= text_len <= max_len):
            continue

        # Alphanumeric ratio check
        alnum_count = len([c for c in text if c.isalnum()])
        if text_len > 0 and (alnum_count / text_len) < alphanumeric_ratio:
            continue

        filtered.append(item)

    return server.save_output(filtered)

@server.tool()
def language_id(
    data: Any,
    expected_lang: str = "vi",
    min_score: float = 0.8,
    text_column: str = "text"
) -> Any:
    """
    Heuristic language identification tool.
    """
    items = server.get_data(data)
    if not items: return []

    logger.info(f"DataJuicer: Filtering for language '{expected_lang}'")

    # Heuristic for Vietnamese
    vi_keywords = {'và', 'của', 'là', 'có', 'được', 'trong', 'cho', 'này', 'với', 'các'}

    filtered = []
    for item in items:
        text = str(item.get(text_column, "")).lower()
        words = set(text.split())

        if expected_lang == "vi":
            matches = len(words & vi_keywords)
            # Heuristic score: ratio of keywords found (simplified)
            if matches >= 2:
                filtered.append(item)
        else:
            # Fallback for other languages
            filtered.append(item)

    return server.save_output(filtered)

if __name__ == "__main__":
    server.run()
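The cleaning step above is built from two plain regular expressions, so its effect can be checked outside the server. A stdlib-only sketch using the same patterns on an illustrative sample string:

import re

# Sample string is illustrative; patterns are copied from clean_content above.
text = "<p>Điều 1. Phạm vi điều chỉnh \U0001F600</p>"
text = re.sub(r'<[^>]+>', '', text)                   # remove_html: strip tags
text = re.sub(r'[\U00010000-\U0010ffff]', '', text)   # remove_emojis: strip astral-plane chars, keep diacritics
print(text.strip())                                    # "Điều 1. Phạm vi điều chỉnh"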
xfmr_zem/servers/instruction_gen/parameter.yaml
ADDED
@@ -0,0 +1,12 @@
# Technical Parameters for Instruction Generation Server
# These are default technical values. Override them in your pipeline YAML.

generate_qa_pairs:
  base_url: "http://localhost:8000/v1"
  model: "default"
  num_pairs: 3
  text_column: "text"
  # prompt_template (English): "Based on the following legal text, create {num_pairs} detailed
  # Question-Answer pairs. Return as a JSON list: [{{'q': '...', 'a': '...'}}]. Text: {text}"
  prompt_template: "Dựa trên văn bản pháp luật sau, hãy tạo {num_pairs} cặp Câu hỏi và Trả lời chi tiết. Trả về định dạng JSON list: [{{'q': '...', 'a': '...'}}]. Văn bản: {text}"

complexity_scorer:
  text_column: "text"
xfmr_zem/servers/instruction_gen/server.py
ADDED
@@ -0,0 +1,90 @@
import os
import sys
import json
import httpx
from typing import Any, Dict, List, Optional
from xfmr_zem.server import ZemServer
from loguru import logger

# Setup logging
logger.remove()
logger.add(sys.stderr, level="INFO")

server = ZemServer("instruction", parameter_file=os.path.join(os.path.dirname(__file__), "parameter.yaml"))

@server.tool()
def generate_qa_pairs(
    data: Any = None,
    base_url: str = "http://localhost:8000/v1",
    model: str = "default",
    num_pairs: int = 3,
    text_column: str = "text",
    # Default prompt (English): "Based on the following legal text, create {num_pairs} detailed
    # Question-Answer pairs. Return as a JSON list: [{{'q': '...', 'a': '...'}}]. Text: {text}"
    prompt_template: str = "Dựa trên văn bản pháp luật sau, hãy tạo {num_pairs} cặp Câu hỏi và Trả lời chi tiết. Trả về định dạng JSON list: [{{'q': '...', 'a': '...'}}]. Văn bản: {text}"
) -> Any:
    raw_data = server.get_data(data)
    if not raw_data: return []

    # Handle wrapped data from previous steps
    if isinstance(raw_data, dict) and 'data' in raw_data:
        actual_items = raw_data['data']
    else:
        actual_items = raw_data

    logger.info(f"InstructionGen: Connecting to vLLM at {base_url} for {len(actual_items)} items")

    processed_items = []
    with httpx.Client(timeout=60.0) as client:
        for item in actual_items:
            text = str(item.get(text_column, ""))
            if not text: continue

            prompt = prompt_template.format(num_pairs=num_pairs, text=text)
            try:
                response = client.post(
                    f"{base_url}/chat/completions",
                    json={
                        "model": model,
                        "messages": [{"role": "user", "content": prompt}],
                        "temperature": 0.3
                    }
                )
                if response.status_code == 200:
                    raw_content = response.json()["choices"][0]["message"]["content"]
                    start = raw_content.find('[')
                    end = raw_content.rfind(']') + 1
                    if start != -1 and end != -1:
                        item["instructions"] = json.loads(raw_content[start:end])
                    else:
                        item["instructions"] = []
            except Exception:
                item["instructions"] = []
            processed_items.append(item)

    return server.save_output(processed_items)

@server.tool()
def complexity_scorer(
    data: Any = None,
    text_column: str = "text"
) -> Any:
    raw_data = server.get_data(data)
    if not raw_data: return []

    if isinstance(raw_data, dict) and 'data' in raw_data:
        actual_items = raw_data['data']
    else:
        actual_items = raw_data

    logger.info(f"InstructionGen: Scoring complexity for {len(actual_items)} items")

    for item in actual_items:
        text = str(item.get(text_column, ""))
        words = text.split()
        unique_words = set(words)
        score = min(1.0, (len(unique_words) / 100) * (len(words) / 500))
        item["complexity_score"] = round(float(score), 2)

    return server.save_output(actual_items)

if __name__ == "__main__":
    server.run()
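Two pieces of the logic above can be verified in isolation: the bracket-extraction step that pulls a JSON list out of a free-form completion, and the word-count-based complexity score, which saturates at 1.0 around 100 unique words and 500 total words. A stdlib-only sketch with illustrative inputs:

import json

# Bracket extraction: take everything from the first '[' to the last ']'.
# The completion string is illustrative.
raw_content = 'Kết quả: [{"q": "Điều 1 quy định gì?", "a": "Phạm vi điều chỉnh."}]'
start, end = raw_content.find('['), raw_content.rfind(']') + 1
print(json.loads(raw_content[start:end]))

# Complexity score from complexity_scorer above, on a synthetic word list.
words = ["từ"] * 400 + [f"w{i}" for i in range(100)]
score = min(1.0, (len(set(words)) / 100) * (len(words) / 500))
print(round(score, 2))   # 1.0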
xfmr_zem/servers/io/server.py
ADDED
@@ -0,0 +1,95 @@
import os
import json
import pandas as pd
from typing import Any, Dict, List, Optional
from xfmr_zem.server import ZemServer
from loguru import logger
import sys

# Initialize server
server = ZemServer("io", parameter_file=os.path.join(os.path.dirname(__file__), "parameter.yaml"))

@server.tool()
def load_jsonl(path: str, return_reference: bool = False) -> Any:
    """
    Load data from a JSONL file. If return_reference is True, returns metadata instead of data.
    """
    logger.info(f"IO: Loading JSONL from {path}")
    if not os.path.exists(path):
        raise FileNotFoundError(f"File not found: {path}")

    if return_reference:
        return {"path": path, "type": "jsonl", "size": os.path.getsize(path)}

    data = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if line.strip():
                data.append(json.loads(line))
    return data

@server.tool()
def save_jsonl(data: Any, path: str) -> Dict[str, Any]:
    """Save data (List or Path reference) to JSONL."""
    os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)

    # If data is a reference (dict with path), we might want to move/copy it
    # but for now, assume data is a List if we are in save_jsonl
    if isinstance(data, dict) and "path" in data:
        import shutil
        shutil.copy(data["path"], path)
        return {"status": "copied", "path": path}

    with open(path, "w", encoding="utf-8") as f:
        for item in data:
            f.write(json.dumps(item, ensure_ascii=False) + "\n")
    return {"status": "success", "path": path, "count": len(data)}

@server.tool()
def load_parquet(path: str, return_reference: bool = True) -> Any:
    """
    Load data from a Parquet file. Defaults to return_reference=True for big data.
    """
    logger.info(f"IO: Handling Parquet from {path}")
    if not os.path.exists(path):
        raise FileNotFoundError(f"File not found: {path}")

    if return_reference:
        return {"path": path, "type": "parquet", "size": os.path.getsize(path)}

    df = pd.read_parquet(path)
    return df.to_dict(orient="records")

@server.tool()
def save_parquet(data: Any, path: str) -> Dict[str, Any]:
    """
    Save data to a Parquet file.
    """
    logger.info(f"IO: Saving Parquet to {path}")
    os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)

    if isinstance(data, dict) and "path" in data:
        src = data["path"]
        logger.info(f"IO: Copying reference from {src} to {path}")
        import shutil
        shutil.copy(src, path)
        return {"status": "copied", "path": path}

    logger.info("IO: Saving raw data to Parquet")
    df = pd.DataFrame(data)
    df.to_parquet(path, index=False)
    return {"status": "success", "path": path, "count": len(data)}

@server.tool()
def scan_directory(path: str, pattern: str = "*") -> List[Dict[str, Any]]:
    """
    Scan a directory and return a list of file references.
    Useful for parallelizing big data processing.
    """
    import glob
    logger.info(f"IO: Scanning directory {path} with pattern {pattern}")
    files = glob.glob(os.path.join(path, pattern))
    return [{"path": f, "type": "file_ref"} for f in files]

if __name__ == "__main__":
    server.run()
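These IO tools exchange either raw row lists or reference dicts of the form {"path": ..., "type": ..., "size": ...}. The JSONL-to-Parquet path they implement can be mirrored with plain pandas; a sketch assuming pandas plus a Parquet engine (pyarrow), with placeholder file names:

import json
import pandas as pd

# Mirrors load_jsonl followed by save_parquet; "in.jsonl" and "out.parquet" are placeholders.
with open("in.jsonl", "r", encoding="utf-8") as f:
    rows = [json.loads(line) for line in f if line.strip()]
pd.DataFrame(rows).to_parquet("out.parquet", index=False)
print({"status": "success", "path": "out.parquet", "count": len(rows)})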
xfmr_zem/servers/llm/server.py
ADDED
@@ -0,0 +1,47 @@
from xfmr_zem.server import ZemServer
from typing import Any, List, Optional
import os
import json
import requests

mcp = ZemServer("LLM-Curation")

def _call_llm(prompt: str, provider: str = "ollama", model: Optional[str] = None) -> str:
    """Helper to route LLM calls."""
    if provider == "ollama":
        model = model or "llama3"
        try:
            url = os.environ.get("OLLAMA_URL", "http://localhost:11434/api/generate")
            response = requests.post(url, json={"model": model, "prompt": prompt, "stream": False})
            return response.json().get("response", "")
        except Exception:
            return f"[Ollama Error] Could not connect to {model}. Fallback: processed {prompt[:20]}..."
    elif provider == "openai":
        api_key = os.environ.get("OPENAI_API_KEY")
        if not api_key: return "[OpenAI Error] Missing API Key."
        # Placeholder for real openai client call
        return f"[OpenAI Mock] Classification/Response for: {prompt[:20]}..."
    return f"[Mock] {prompt}"

@mcp.tool()
def mask_pii(data: Any, provider: str = "ollama") -> List[Any]:
    """Smart PII masking using LLM."""
    dataset = mcp.get_data(data)
    for item in dataset:
        if "text" in item:
            prompt = f"Remove all PII from this text and return ONLY the cleaned text: {item['text']}"
            item["text"] = _call_llm(prompt, provider=provider)
    return dataset

@mcp.tool()
def classify_domain(data: Any, categories: List[str] = ["Tech", "Finance", "Legal"], provider: str = "ollama") -> List[Any]:
    """Classify data domain using LLM."""
    dataset = mcp.get_data(data)
    for item in dataset:
        if "text" in item:
            prompt = f"Classify this text into one of {categories}. Return ONLY the category name: {item['text']}"
            item["domain"] = _call_llm(prompt, provider=provider).strip()
    return dataset

if __name__ == "__main__":
    mcp.run()
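_call_llm is an undecorated helper, so it can be exercised directly when experimenting with this module; if the Ollama endpoint is unreachable it returns the fallback string rather than raising. A sketch intended to be run from within this module (the model name and sample text are illustrative):

import os

# Optional: point at a non-default Ollama endpoint; otherwise localhost:11434 is used.
os.environ.setdefault("OLLAMA_URL", "http://localhost:11434/api/generate")

# _call_llm is defined above in this module; model and sample text are illustrative.
masked = _call_llm(
    "Remove all PII from this text and return ONLY the cleaned text: Nguyễn Văn A, 0901 234 567",
    provider="ollama",
    model="llama3",
)
print(masked)   # cleaned text, or "[Ollama Error] ..." if no Ollama server is reachable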
xfmr_zem/servers/nemo_curator/parameter.yaml
ADDED
@@ -0,0 +1,17 @@
# Technical Parameters for Nemo Curator Server
# These are default technical values. Override them in your pipeline YAML.

normalize:
  normalization: "NFC"
  text_column: "text"
  cleanup_patterns: []

quality_filter:
  min_words: 50
  max_non_alpha_ratio: 0.25
  text_column: "text"

deduplicate:
  text_column: "text"
  method: "exact"
  threshold: 0.85
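The nemo_curator server itself is not part of this diff excerpt, but the "normalization: NFC" default above refers to standard Unicode normalization, which matters for Vietnamese text where composed and decomposed forms render identically. A stdlib-only illustration:

import unicodedata

# Decomposed input: base letters followed by combining marks (illustrative string).
s = "Tie\u0302\u0301ng Vie\u0302\u0323t"
print(unicodedata.normalize("NFC", s))                 # "Tiếng Việt" (precomposed code points)
print(len(s), len(unicodedata.normalize("NFC", s)))    # 14 10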