xfmr-zem 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
xfmr_zem/server.py ADDED
@@ -0,0 +1,188 @@
1
+
2
+ from typing import Any, Callable, Dict, List, Optional, Union
3
+ import yaml
4
+ from pathlib import Path
5
+ from fastmcp import FastMCP
6
+ import inspect
7
+
8
class ZemServer(FastMCP):
    """
    Base class for Zem MCP Servers.

    Extends FastMCP to support parameter loading and standardized tool
    registration. Parameters are merged from two sources, in order:
      1. an optional YAML ``parameter_file``,
      2. the ``ZEM_PARAMETERS`` environment variable (YAML string,
         typically injected by a PipelineClient), which overrides the file.
    """

    def __init__(
        self,
        name: str,
        # NOTE(review): unused and not snake_case (PEP 8); kept verbatim so
        # existing keyword callers don't break — confirm before removing.
        Dependencies: Optional[List[str]] = None,
        parameter_file: Optional[str] = None,
        **kwargs
    ):
        super().__init__(name=name, **kwargs)
        self.parameter_file = parameter_file
        # Effective parameter store (file values, then env overrides).
        self.parameters: Dict[str, Any] = {}

        # 1. Load from file
        if parameter_file:
            self.load_parameters(parameter_file)

        # 2. Override with env params (from PipelineClient)
        import os
        from loguru import logger
        env_params_str = os.environ.get("ZEM_PARAMETERS")
        if env_params_str:
            try:
                env_params = yaml.safe_load(env_params_str)
                if isinstance(env_params, dict):
                    self._merge_parameters(env_params)
            except Exception as e:
                # A malformed env payload must not prevent server startup.
                # CHANGED: was print(); use loguru like the rest of the class.
                logger.error(f"Error loading ZEM_PARAMETERS: {e}")

    def load_parameters(self, file_path: str) -> Dict[str, Any]:
        """Load parameters from a YAML file and merge them.

        Returns the merged parameter dict, or {} when the file is missing
        (a missing file is not an error).
        """
        path = Path(file_path)
        if path.exists():
            with open(path, "r") as f:
                # An empty YAML file parses to None — treat it as {}.
                file_params = yaml.safe_load(f) or {}
            self._merge_parameters(file_params)
            return self.parameters
        return {}

    def _merge_parameters(self, new_params: Dict[str, Any]):
        """Deep merge and dot-notation expansion for parameters.

        A key like "tool.param" is expanded to {"tool": {"param": value}}
        before merging; nested dicts are merged recursively, scalars are
        overwritten.
        """
        for key, value in new_params.items():
            if "." in key:
                # Expand "tool.param" to {"tool": {"param": value}}
                parts = key.split(".")
                d = self.parameters
                for part in parts[:-1]:
                    # Any non-dict intermediate is replaced by a fresh dict.
                    if part not in d or not isinstance(d[part], dict):
                        d[part] = {}
                    d = d[part]

                last_part = parts[-1]
                if isinstance(value, dict) and last_part in d and isinstance(d[last_part], dict):
                    self._deep_update(d[last_part], value)
                else:
                    d[last_part] = value
            else:
                # Top level merge
                if isinstance(value, dict) and key in self.parameters and isinstance(self.parameters[key], dict):
                    self._deep_update(self.parameters[key], value)
                else:
                    self.parameters[key] = value

    def _deep_update(self, target: Dict[str, Any], source: Dict[str, Any]):
        """Helper for deep dictionary update (source wins on conflicts)."""
        for k, v in source.items():
            if isinstance(v, dict) and k in target and isinstance(target[k], dict):
                self._deep_update(target[k], v)
            else:
                target[k] = v

    # Removed custom tool decorator to fix multiple values for argument 'name' error
    # Inherit directly from FastMCP.tool

    def get_data(self, data: Any) -> List[Dict[str, Any]]:
        """
        Standardized way to get data, supporting both direct lists and file references.

        Accepts:
          * list — returned as-is;
          * dict with "path" — loaded from .jsonl/.csv/.parquet;
          * plain dict — wrapped as a single-record list;
          * str — tried as URI/local path, then as a JSON string, finally
            wrapped as [{"text": data}];
          * anything else — wrapped as [{"raw": str(data)}] with a warning.
        """
        import os
        from loguru import logger

        logger.debug(f"Server {self.name}.get_data input type: {type(data)}")

        if isinstance(data, list):
            return data

        if isinstance(data, dict):
            if "path" in data:
                path = data["path"]
                ext = os.path.splitext(path)[1].lower()

                logger.debug(f"Server {self.name}: Loading reference {path}")
                if ext == ".jsonl":
                    import json
                    with open(path, "r", encoding="utf-8") as f:
                        return [json.loads(line) for line in f if line.strip()]
                elif ext == ".csv":
                    import pandas as pd
                    return pd.read_csv(path).to_dict(orient="records")
                elif ext == ".parquet":
                    import pandas as pd
                    return pd.read_parquet(path).to_dict(orient="records")
                # Unknown extension falls through to the final warning below.
            else:
                # Single record dictionary
                logger.debug(f"Server {self.name}: Wrapping single dict in list")
                return [data]

        if isinstance(data, str):
            # 1. URL/Cloud URI or Local Path
            # NOTE(review): startswith("http") also matches e.g. "httpfoo" —
            # presumably only http(s) URLs are intended; confirm.
            is_uri = data.startswith(("http", "s3://", "gs://"))
            if is_uri or os.path.exists(data):
                ext = os.path.splitext(data)[1].lower()
                logger.debug(f"Server {self.name}: Loading data from {data}")
                try:
                    import pandas as pd
                    if ext == ".parquet":
                        return pd.read_parquet(data).to_dict(orient="records")
                    elif ext == ".csv":
                        return pd.read_csv(data).to_dict(orient="records")
                    elif ext == ".jsonl":
                        if is_uri:
                            return pd.read_json(data, lines=True).to_dict(orient="records")
                        import json
                        with open(data, "r", encoding="utf-8") as f:
                            return [json.loads(line) for line in f if line.strip()]
                except Exception as e:
                    # Fall through to JSON/raw-text handling below.
                    logger.error(f"Error loading data from {data}: {e}")

            # 2. Treat as raw text or JSON string
            try:
                import json
                parsed = json.loads(data)
                if isinstance(parsed, list):
                    return parsed
                if isinstance(parsed, dict):
                    return [parsed]
            except Exception:
                # BUGFIX: was a bare `except:`, which also swallowed
                # SystemExit/KeyboardInterrupt.
                pass
            return [{"text": data}]

        logger.warning(f"Server {self.name}: Unrecognized data type {type(data)}")
        return [{"raw": str(data)}]

    def save_output(self, data: Any, format: str = "parquet") -> Dict[str, Any]:
        """
        Saves output to a temporary file and returns a reference.
        Prevents large data from being sent over JSON-RPC.

        Args:
            data: records accepted by pandas.DataFrame (parquet) or an
                iterable of JSON-serializable items (jsonl).
            format: "parquet" (default) or "jsonl".

        Returns:
            {"path": ..., "type": format, "size": <bytes written>}

        Raises:
            ValueError: for an unsupported ``format``. (BUGFIX: previously an
            unknown format wrote nothing and crashed later in getsize with a
            confusing FileNotFoundError.)
        """
        import uuid
        import os
        from loguru import logger

        base_dir = "/tmp/zem_artifacts"
        os.makedirs(base_dir, exist_ok=True)

        file_id = str(uuid.uuid4())[:8]
        path = os.path.join(base_dir, f"{self.name}_output_{file_id}.{format}")

        logger.info(f"Server {self.name}: Saving result to reference {path}")

        if format == "parquet":
            import pandas as pd
            logger.debug(f"Server {self.name}: Converting to DataFrame, data type: {type(data)}")
            try:
                df = pd.DataFrame(data)
                df.to_parquet(path, index=False)
            except Exception as e:
                logger.error(f"Server {self.name}: Failed to create DataFrame from {type(data)}: {e}")
                raise
        elif format == "jsonl":
            import json
            with open(path, "w", encoding="utf-8") as f:
                for item in data:
                    f.write(json.dumps(item, ensure_ascii=False) + "\n")
        else:
            raise ValueError(f"Unsupported output format: {format}")

        return {"path": path, "type": format, "size": os.path.getsize(path)}

    def run(self, transport: str = "stdio"):
        """Run the server on the given transport (default stdio), banner suppressed."""
        super().run(transport=transport, show_banner=False)
@@ -0,0 +1,17 @@
1
+ # Technical Parameters for DataJuicer Server
2
+
3
+ clean_content:
4
+ remove_html: true
5
+ remove_emojis: false
6
+ text_column: "text"
7
+
8
+ refining_filter:
9
+ min_len: 10
10
+ max_len: 100000
11
+ alphanumeric_ratio: 0.1
12
+ text_column: "text"
13
+
14
+ language_id:
15
+ expected_lang: "vi"
16
+ min_score: 0.8
17
+ text_column: "text"
@@ -0,0 +1,113 @@
1
+ import os
2
+ import sys
3
+ import re
4
+ from typing import Any, Dict, List, Optional
5
+ from xfmr_zem.server import ZemServer
6
+ from loguru import logger
7
+
8
# Setup logging
logger.remove()
logger.add(sys.stderr, level="INFO")

# Module-level server instance; the @server.tool() functions below register
# against it. Defaults come from the adjacent parameter.yaml.
server = ZemServer("data_juicer", parameter_file=os.path.join(os.path.dirname(__file__), "parameter.yaml"))
13
+
14
@server.tool()
def clean_content(
    data: Any,
    remove_html: bool = True,
    remove_emojis: bool = False,
    text_column: str = "text"
) -> Any:
    """
    Flexible content cleaning tool using DataJuicer logic.
    """
    records = server.get_data(data)
    if not records:
        return []

    logger.info(f"DataJuicer: Cleaning content (remove_html={remove_html}, remove_emojis={remove_emojis})")

    for record in records:
        # Records without the text column are passed through untouched.
        if text_column not in record:
            continue
        content = str(record[text_column])

        if remove_html:
            # Strip anything that looks like a markup tag.
            content = re.sub(r'<[^>]+>', '', content)

        if remove_emojis:
            # Remove supplementary-plane Unicode (likely emojis) while
            # preserving Vietnamese diacritics and other BMP text.
            content = re.sub(r'[\U00010000-\U0010ffff]', '', content)

        record[text_column] = content.strip()

    return server.save_output(records)
44
+
45
@server.tool()
def refining_filter(
    data: Any,
    min_len: int = 10,
    max_len: int = 100000,
    alphanumeric_ratio: float = 0.1,
    text_column: str = "text"
) -> Any:
    """
    Filter items based on technical refining metrics.
    """
    records = server.get_data(data)
    if not records:
        return []

    logger.info(f"DataJuicer: Refining filter (min_len={min_len}, alpha_ratio={alphanumeric_ratio})")

    kept = []
    for record in records:
        content = str(record.get(text_column, ""))
        length = len(content)

        # Length gate: drop anything outside [min_len, max_len].
        if length < min_len or length > max_len:
            continue

        # Density gate: drop records whose alphanumeric ratio is too low.
        alnum = sum(1 for ch in content if ch.isalnum())
        if length > 0 and alnum / length < alphanumeric_ratio:
            continue

        kept.append(record)

    return server.save_output(kept)
77
+
78
@server.tool()
def language_id(
    data: Any,
    expected_lang: str = "vi",
    min_score: float = 0.8,
    text_column: str = "text"
) -> Any:
    """
    Heuristic language identification tool.
    """
    records = server.get_data(data)
    if not records:
        return []

    logger.info(f"DataJuicer: Filtering for language '{expected_lang}'")

    # Common Vietnamese function words used as a cheap language signal.
    vi_keywords = {'và', 'của', 'là', 'có', 'được', 'trong', 'cho', 'này', 'với', 'các'}

    kept = []
    for record in records:
        tokens = set(str(record.get(text_column, "")).lower().split())

        if expected_lang == "vi":
            # Keep when at least two keyword hits.
            # NOTE(review): min_score is currently unused by this heuristic —
            # confirm whether it should feed a real score threshold.
            if len(tokens & vi_keywords) >= 2:
                kept.append(record)
        else:
            # Non-Vietnamese targets pass through unfiltered.
            kept.append(record)

    return server.save_output(kept)
111
+
112
# Run as an MCP server (stdio transport) when executed directly.
if __name__ == "__main__":
    server.run()
@@ -0,0 +1,12 @@
1
+ # Technical Parameters for Instruction Generation Server
2
+ # These are default technical values. Override them in your pipeline YAML.
3
+
4
+ generate_qa_pairs:
5
+ base_url: "http://localhost:8000/v1"
6
+ model: "default"
7
+ num_pairs: 3
8
+ text_column: "text"
9
+ prompt_template: "Dựa trên văn bản pháp luật sau, hãy tạo {num_pairs} cặp Câu hỏi và Trả lời chi tiết. Trả về định dạng JSON list: [{{'q': '...', 'a': '...'}}]. Văn bản: {text}"
10
+
11
+ complexity_scorer:
12
+ text_column: "text"
@@ -0,0 +1,90 @@
1
+ import os
2
+ import sys
3
+ import json
4
+ import httpx
5
+ from typing import Any, Dict, List, Optional
6
+ from xfmr_zem.server import ZemServer
7
+ from loguru import logger
8
+
9
# Setup logging
logger.remove()
logger.add(sys.stderr, level="INFO")

# Module-level server instance; the @server.tool() functions below register
# against it. Defaults come from the adjacent parameter.yaml.
server = ZemServer("instruction", parameter_file=os.path.join(os.path.dirname(__file__), "parameter.yaml"))
14
+
15
@server.tool()
def generate_qa_pairs(
    data: Any = None,
    base_url: str = "http://localhost:8000/v1",
    model: str = "default",
    num_pairs: int = 3,
    text_column: str = "text",
    prompt_template: str = "Dựa trên văn bản pháp luật sau, hãy tạo {num_pairs} cặp Câu hỏi và Trả lời chi tiết. Trả về định dạng JSON list: [{{'q': '...', 'a': '...'}}]. Văn bản: {text}"
) -> Any:
    """
    Generate Q/A pairs per record via an OpenAI-compatible chat endpoint (e.g. vLLM).

    Every processed record gains an "instructions" key: the JSON list parsed
    from the model reply, or [] when the request or parsing fails. Records
    with an empty text column are skipped. Returns a file reference from
    server.save_output.
    """
    raw_data = server.get_data(data)
    if not raw_data:
        return []

    # Handle wrapped data from previous steps
    if isinstance(raw_data, dict) and 'data' in raw_data:
        actual_items = raw_data['data']
    else:
        actual_items = raw_data

    logger.info(f"InstructionGen: Connecting to vLLM at {base_url} for {len(actual_items)} items")

    processed_items = []
    with httpx.Client(timeout=60.0) as client:
        for item in actual_items:
            text = str(item.get(text_column, ""))
            if not text:
                continue

            # BUGFIX: set the default up-front so every emitted record has the
            # key — previously a non-200 response left "instructions" missing.
            item["instructions"] = []
            prompt = prompt_template.format(num_pairs=num_pairs, text=text)
            try:
                response = client.post(
                    f"{base_url}/chat/completions",
                    json={
                        "model": model,
                        "messages": [{"role": "user", "content": prompt}],
                        "temperature": 0.3
                    }
                )
                if response.status_code == 200:
                    raw_content = response.json()["choices"][0]["message"]["content"]
                    # Extract the first JSON list embedded in the reply.
                    start = raw_content.find('[')
                    end = raw_content.rfind(']') + 1
                    # BUGFIX: rfind returns -1 when ']' is absent, making
                    # end == 0 — the old `end != -1` check was always true.
                    if start != -1 and end > start:
                        item["instructions"] = json.loads(raw_content[start:end])
            except Exception:
                # Best-effort: keep the record with empty instructions.
                item["instructions"] = []
            processed_items.append(item)

    return server.save_output(processed_items)
64
+
65
@server.tool()
def complexity_scorer(
    data: Any = None,
    text_column: str = "text"
) -> Any:
    # Resolve references/lists into concrete records.
    records = server.get_data(data)
    if not records:
        return []

    # Unwrap {'data': [...]} envelopes produced by some upstream steps.
    if isinstance(records, dict) and 'data' in records:
        actual_items = records['data']
    else:
        actual_items = records

    logger.info(f"InstructionGen: Scoring complexity for {len(actual_items)} items")

    for record in actual_items:
        tokens = str(record.get(text_column, "")).split()
        # Crude lexical-diversity × length heuristic, capped at 1.0.
        raw_score = (len(set(tokens)) / 100) * (len(tokens) / 500)
        record["complexity_score"] = round(float(min(1.0, raw_score)), 2)

    return server.save_output(actual_items)
88
+
89
# Run as an MCP server (stdio transport) when executed directly.
if __name__ == "__main__":
    server.run()
@@ -0,0 +1,10 @@
1
+ # IO Server Configuration
2
+ # Default settings for loading and saving data
3
+
4
+ load_settings:
5
+ chunk_size: 1000
6
+ encoding: "utf-8"
7
+
8
+ save_settings:
9
+ overwrite: true
10
+ encoding: "utf-8"
@@ -0,0 +1,95 @@
1
+ import os
2
+ import json
3
+ import pandas as pd
4
+ from typing import Any, Dict, List, Optional
5
+ from xfmr_zem.server import ZemServer
6
+ from loguru import logger
7
+ import sys
8
+
9
# Initialize server
# Module-level server instance; the @server.tool() functions below register
# against it. Defaults come from the adjacent parameter.yaml.
server = ZemServer("io", parameter_file=os.path.join(os.path.dirname(__file__), "parameter.yaml"))
11
+
12
@server.tool()
def load_jsonl(path: str, return_reference: bool = False) -> Any:
    """
    Load data from a JSONL file. If return_reference is True, returns metadata instead of data.
    """
    logger.info(f"IO: Loading JSONL from {path}")
    if not os.path.exists(path):
        raise FileNotFoundError(f"File not found: {path}")

    if return_reference:
        # Hand back a lightweight pointer instead of materializing the rows.
        return {"path": path, "type": "jsonl", "size": os.path.getsize(path)}

    # Parse one JSON object per non-blank line.
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]
30
+
31
@server.tool()
def save_jsonl(data: Any, path: str) -> Dict[str, Any]:
    """Save data (List or Path reference) to JSONL."""
    os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)

    # A {"path": ...} payload is a file reference from an upstream step —
    # copy the existing artifact rather than re-serializing it.
    if isinstance(data, dict) and "path" in data:
        import shutil
        shutil.copy(data["path"], path)
        return {"status": "copied", "path": path}

    # Otherwise serialize the records one JSON object per line.
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(row, ensure_ascii=False) + "\n" for row in data)
    return {"status": "success", "path": path, "count": len(data)}
47
+
48
@server.tool()
def load_parquet(path: str, return_reference: bool = True) -> Any:
    """
    Load data from a Parquet file. Defaults to return_reference=True for big data.
    """
    logger.info(f"IO: Handling Parquet from {path}")
    if not os.path.exists(path):
        raise FileNotFoundError(f"File not found: {path}")

    # Big-data friendly default: pass a reference, not the rows themselves.
    if return_reference:
        return {"path": path, "type": "parquet", "size": os.path.getsize(path)}

    return pd.read_parquet(path).to_dict(orient="records")
62
+
63
@server.tool()
def save_parquet(data: Any, path: str) -> Dict[str, Any]:
    """
    Save data to a Parquet file.
    """
    logger.info(f"IO: Saving Parquet to {path}")
    os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)

    # A {"path": ...} payload is a file reference from an upstream step —
    # relocate the existing artifact instead of rebuilding it.
    if isinstance(data, dict) and "path" in data:
        src = data["path"]
        logger.info(f"IO: Copying reference from {src} to {path}")
        import shutil
        shutil.copy(src, path)
        return {"status": "copied", "path": path}

    logger.info("IO: Saving raw data to Parquet")
    frame = pd.DataFrame(data)
    frame.to_parquet(path, index=False)
    return {"status": "success", "path": path, "count": len(data)}
82
+
83
@server.tool()
def scan_directory(path: str, pattern: str = "*") -> List[Dict[str, Any]]:
    """
    Scan a directory and return a list of file references.
    Useful for parallelizing big data processing.
    """
    import glob
    logger.info(f"IO: Scanning directory {path} with pattern {pattern}")
    # Non-recursive glob match inside `path`.
    matches = glob.glob(os.path.join(path, pattern))
    return [{"path": match, "type": "file_ref"} for match in matches]
93
+
94
# Run as an MCP server (stdio transport) when executed directly.
if __name__ == "__main__":
    server.run()
@@ -0,0 +1,47 @@
1
+ from xfmr_zem.server import ZemServer
2
+ from typing import Any, List, Optional
3
+ import os
4
+ import json
5
+ import requests
6
+
7
# Module-level server instance; the @mcp.tool() functions below register
# against it. No parameter file is used for this server.
mcp = ZemServer("LLM-Curation")
8
+
9
+ def _call_llm(prompt: str, provider: str = "ollama", model: Optional[str] = None) -> str:
10
+ """Helper to route LLM calls."""
11
+ if provider == "ollama":
12
+ model = model or "llama3"
13
+ try:
14
+ url = os.environ.get("OLLAMA_URL", "http://localhost:11434/api/generate")
15
+ response = requests.post(url, json={"model": model, "prompt": prompt, "stream": False})
16
+ return response.json().get("response", "")
17
+ except Exception:
18
+ return f"[Ollama Error] Could not connect to {model}. Fallback: processed {prompt[:20]}..."
19
+ elif provider == "openai":
20
+ api_key = os.environ.get("OPENAI_API_KEY")
21
+ if not api_key: return "[OpenAI Error] Missing API Key."
22
+ # Placeholder for real openai client call
23
+ return f"[OpenAI Mock] Classification/Response for: {prompt[:20]}..."
24
+ return f"[Mock] {prompt}"
25
+
26
@mcp.tool()
def mask_pii(data: Any, provider: str = "ollama") -> List[Any]:
    """Smart PII masking using LLM."""
    records = mcp.get_data(data)
    for record in records:
        # Only records carrying a "text" field are rewritten.
        if "text" not in record:
            continue
        prompt = f"Remove all PII from this text and return ONLY the cleaned text: {record['text']}"
        record["text"] = _call_llm(prompt, provider=provider)
    return records
35
+
36
@mcp.tool()
def classify_domain(data: Any, categories: List[str] = ["Tech", "Finance", "Legal"], provider: str = "ollama") -> List[Any]:
    """Classify data domain using LLM."""
    # NOTE(review): mutable default argument — harmless here since it is
    # never mutated, but consider a None sentinel if that ever changes.
    records = mcp.get_data(data)
    for record in records:
        if "text" not in record:
            continue
        prompt = f"Classify this text into one of {categories}. Return ONLY the category name: {record['text']}"
        record["domain"] = _call_llm(prompt, provider=provider).strip()
    return records
45
+
46
# Run as an MCP server (stdio transport) when executed directly.
if __name__ == "__main__":
    mcp.run()
@@ -0,0 +1,17 @@
1
+ # Technical Parameters for Nemo Curator Server
2
+ # These are default technical values. Override them in your pipeline YAML.
3
+
4
+ normalize:
5
+ normalization: "NFC"
6
+ text_column: "text"
7
+ cleanup_patterns: []
8
+
9
+ quality_filter:
10
+ min_words: 50
11
+ max_non_alpha_ratio: 0.25
12
+ text_column: "text"
13
+
14
+ deduplicate:
15
+ text_column: "text"
16
+ method: "exact"
17
+ threshold: 0.85