wiba-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wiba/__init__.py +909 -0
- wiba-0.1.0.dist-info/LICENSE +21 -0
- wiba-0.1.0.dist-info/METADATA +157 -0
- wiba-0.1.0.dist-info/RECORD +6 -0
- wiba-0.1.0.dist-info/WHEEL +5 -0
- wiba-0.1.0.dist-info/top_level.txt +1 -0
wiba/__init__.py
ADDED
@@ -0,0 +1,909 @@
import requests
from requests.adapters import HTTPAdapter, Retry
from dataclasses import dataclass, field
from typing import List, Optional, Dict, Any, Generic, TypeVar, Union
import structlog
from datetime import datetime
import uuid
from tqdm.auto import tqdm
import pandas as pd
import numpy as np
from io import StringIO
import threading

T = TypeVar('T')

@dataclass
class ClientConfig:
    """Configuration for the WIBA client"""
    environment: str = "production"
    log_level: str = "INFO"
    api_token: Optional[str] = None
    api_url: str = "https://wiba.dev"

@dataclass
class ClientStatistics:
    """Statistics for API usage"""
    total_requests: int = 0
    method_calls: Dict[str, int] = field(default_factory=lambda: {
        'detect': 0,
        'extract': 0,
        'stance': 0,
        'discover_arguments': 0
    })
    last_request_timestamp: Optional[datetime] = None
    total_texts_processed: int = 0
    errors_encountered: int = 0

class WIBAError(Exception):
    """Base exception for WIBA client errors"""
    pass

class ValidationError(WIBAError):
    """Raised when input validation fails"""
    pass

@dataclass
class ResponseMetadata:
    """Metadata for API responses"""
    request_id: str
    timestamp: datetime = field(default_factory=datetime.utcnow)
    processing_time: float = 0.0

@dataclass
class WIBAResponse(Generic[T]):
    """Generic response wrapper for all WIBA API responses"""
    data: T
    metadata: ResponseMetadata
    status: str = "success"
    errors: List[Dict[str, Any]] = field(default_factory=list)

@dataclass
class ArgumentDetectionResult:
    """Result of argument detection for a single text"""
    text: str
    argument_prediction: str  # "Argument" or "NoArgument"
    confidence_score: float
    argument_components: Optional[Dict[str, Any]] = None

@dataclass
class TopicExtractionResult:
    """Result of topic extraction for a single text"""
    text: str
    topics: List[str]
    topic_metadata: Optional[Dict[str, Any]] = None

@dataclass
class StanceAnalysisResult:
    """Result of stance analysis for a text-topic pair"""
    text: str
    topic: str
    stance: str
    supporting_evidence: Optional[List[str]] = None

@dataclass
class SegmentResult:
    """Result of text segmentation"""
    original_id: int
    text_segment: str
    start_index: int
    end_index: int
    text: str
    processed_text: str
    parent_id: Optional[int] = None

@dataclass
class CalculatedSegmentResult:
    """Result of segment calculation"""
    id: int
    text: str
    processed_text: str
    text_segment: str
    start_index: int
    end_index: int
    argument_prediction: str  # "Argument" or "NoArgument"
    argument_confidence: float  # Confidence score for argument prediction
    original_id: int
    parent_id: Optional[int] = None

@dataclass
class ArgumentSegmentResult:
    """Result of argument discovery in text segments"""
    id: int
    text: str  # Original full text
    text_segment: str  # The segment text
    start_index: int  # Start index in sentences
    end_index: int  # End index in sentences
    argument_prediction: str  # "Argument" or "NoArgument"
    argument_confidence: float  # Confidence score for argument prediction
    overlapping_segments: List[str]  # IDs of overlapping segments
    processed_text: str  # Preprocessed text segment

class ResponseFactory:
    """Factory for creating response objects from raw API responses"""

    @staticmethod
    def create_detection_response(raw_response: List[Dict[str, Any]], input_text: str) -> WIBAResponse[List[ArgumentDetectionResult]]:
        metadata = ResponseMetadata(
            request_id=str(uuid.uuid4()),
            processing_time=0.0
        )

        # API returns a list of dictionaries with argument_prediction and argument_confidence
        result = ArgumentDetectionResult(
            text=input_text,
            argument_prediction=raw_response[0]['argument_prediction'],
            confidence_score=raw_response[0]['argument_confidence'],
            argument_components=None
        )

        return WIBAResponse(data=[result], metadata=metadata)

    @staticmethod
    def create_extraction_response(raw_response: List[Dict[str, Any]], input_text: str) -> WIBAResponse[List[TopicExtractionResult]]:
        metadata = ResponseMetadata(
            request_id=str(uuid.uuid4()),
            processing_time=0.0
        )

        # API returns a list of dictionaries with extracted_topic
        topic = raw_response[0]['extracted_topic']
        standardized_topic = WIBA.TOPIC_VALUES.get(topic, topic)

        result = TopicExtractionResult(
            text=input_text,
            topics=[standardized_topic] if standardized_topic != 'NoTopic' else [],
            topic_metadata=None
        )

        return WIBAResponse(data=[result], metadata=metadata)

    @staticmethod
    def create_stance_response(raw_response: List[Dict[str, Any]], input_text: str, input_topic: str) -> WIBAResponse[List[StanceAnalysisResult]]:
        metadata = ResponseMetadata(
            request_id=str(uuid.uuid4()),
            processing_time=0.0
        )

        # Use the class-level stance mapping
        stance_text = WIBA.STANCE_MAP.get(raw_response[0]['stance_prediction'], raw_response[0]['stance_prediction'])

        result = StanceAnalysisResult(
            text=input_text,
            topic=input_topic,
            stance=stance_text,
            supporting_evidence=None
        )

        return WIBAResponse(data=[result], metadata=metadata)

class WIBA:
    """Client for interacting with the WIBA API"""

    # Stance mapping at class level
    STANCE_MAP = {
        'Argument in Favor': 'Favor',
        'Argument Against': 'Against',
        'No Argument': 'NoArgument'
    }

    # Standardized values as class constants
    ARGUMENT_VALUES = {
        'argument': 'Argument',
        'non-argument': 'NoArgument',
        'non_argument': 'NoArgument',
        'no-argument': 'NoArgument',
        'noargument': 'NoArgument'
    }

    TOPIC_VALUES = {
        'No Topic': 'NoTopic',
        'no topic': 'NoTopic',
        'no-topic': 'NoTopic',
        'notopic': 'NoTopic'
    }

    def __init__(self, api_token: Optional[str] = None, config: Optional[ClientConfig] = None, pool_connections: int = 100, pool_maxsize: int = 100):
        """Initialize the WIBA client.

        Args:
            api_token: API token for authentication
            config: Optional client configuration
            pool_connections: Number of urllib3 connection pools to cache
            pool_maxsize: Maximum number of connections to save in the pool
        """
        self.config = config or ClientConfig()

        # Set API token from either the direct argument or the config
        self.api_token = api_token or self.config.api_token
        if not self.api_token:
            raise ValidationError("API token is required. Provide it either through the api_token parameter or ClientConfig.")

        # Initialize statistics
        self.statistics = ClientStatistics()

        # Set up structured logging
        self.logger = structlog.get_logger(
            "wiba",
            env=self.config.environment
        )

        # Initialize session with connection pooling and retry strategy
        self.session = self._create_session(pool_connections, pool_maxsize)

        # Thread-local storage for request-specific data
        self._thread_local = threading.local()

    def _create_session(self, pool_connections: int, pool_maxsize: int) -> requests.Session:
        """Create a new session with connection pooling and retry strategy."""
        session = requests.Session()

        # Configure retry strategy
        retries = Retry(
            total=5,
            backoff_factor=0.1,
            status_forcelist=[500, 502, 503, 504, 429],  # Include rate limiting
            allowed_methods=["GET", "POST"]  # Allow retries on both GET and POST
        )

        # Configure connection pooling
        adapter = HTTPAdapter(
            pool_connections=pool_connections,  # Number of urllib3 connection pools to cache
            pool_maxsize=pool_maxsize,  # Maximum number of connections to save in the pool
            max_retries=retries,
            pool_block=False  # Don't block when the pool is full; raise an error instead
        )

        # Mount adapter for both HTTP and HTTPS
        session.mount('http://', adapter)
        session.mount('https://', adapter)

        return session

    def __enter__(self):
        """Context manager entry."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit with proper cleanup."""
        self.close()

    def close(self):
        """Close the client and clean up resources."""
        if hasattr(self, 'session'):
            self.session.close()

    def _get_request_id(self) -> str:
        """Get a thread-local request ID."""
        if not hasattr(self._thread_local, 'request_id'):
            self._thread_local.request_id = str(uuid.uuid4())
        return self._thread_local.request_id

    def _update_statistics(self, method_name: str, num_texts: int = 1, error: bool = False) -> None:
        """Update usage statistics.

        Args:
            method_name: Name of the method being called
            num_texts: Number of texts being processed
            error: Whether an error occurred
        """
        self.statistics.total_requests += 1
        self.statistics.method_calls[method_name] = self.statistics.method_calls.get(method_name, 0) + 1
        self.statistics.last_request_timestamp = datetime.utcnow()
        self.statistics.total_texts_processed += num_texts
        if error:
            self.statistics.errors_encountered += 1

    def get_statistics(self) -> Dict[str, Any]:
        """Get current usage statistics.

        Returns:
            Dictionary containing usage statistics in the server API format
        """
        # Calculate total API calls
        total_api_calls = sum(self.statistics.method_calls.values())

        # Calculate method percentages
        method_breakdown = {}
        for method, count in self.statistics.method_calls.items():
            method_breakdown[method] = {
                'count': count,
                'percentage': round((count / total_api_calls * 100) if total_api_calls > 0 else 0, 1)
            }

        # Ensure all methods have entries
        for method in ['detect', 'extract', 'stance', 'discover_arguments']:
            if method not in method_breakdown:
                method_breakdown[method] = {'count': 0, 'percentage': 0}

        return {
            'overview': {
                'total_api_calls': total_api_calls,
                'total_texts_processed': self.statistics.total_texts_processed,
                'last_request': self.statistics.last_request_timestamp.isoformat() if self.statistics.last_request_timestamp else None,
                'errors_encountered': self.statistics.errors_encountered
            },
            'method_breakdown': method_breakdown
        }

    def _make_request(
        self,
        method: str,
        endpoint: str,
        data: Optional[Dict[str, Any]] = None,
        headers: Optional[Dict[str, Any]] = None
    ) -> Any:
        """Make an HTTP request to the API and return the parsed JSON response."""
        url = self.config.api_url + endpoint

        # Use the last path component of the endpoint as the method name for statistics
        method_name = endpoint.split('/')[-1]

        # Add request ID and authentication for tracking
        request_id = self._get_request_id()
        request_headers = {
            "X-Request-ID": request_id,
            "X-Requested-With": "XMLHttpRequest",
            "Content-Type": "application/json",
            "X-API-Token": self.api_token
        }
        if headers:
            request_headers.update(headers)

        try:
            # Only transform non-segment requests
            if data and endpoint not in ['/api/create_segments', '/api/calculate_segments', '/api/discover_arguments']:
                # Convert single text to list format for the API
                if 'text' in data:
                    json_data = {'texts': [data['text']]}
                elif 'texts' in data and isinstance(data['texts'], str):
                    json_data = {'texts': [data['texts']]}
                else:
                    json_data = data
            else:
                json_data = data

            response = self.session.request(
                method=method,
                url=url,
                json=json_data,
                headers=request_headers,
                timeout=(5.0, 30.0)  # (connect timeout, read timeout)
            )

            # Handle all error cases
            if not response.ok:
                self._update_statistics(method_name, error=True)
                if response.status_code == 401:
                    raise ValidationError("Invalid API token")
                elif response.status_code == 403:
                    raise ValidationError("API token does not have sufficient permissions")
                elif response.status_code == 400:
                    raise ValidationError(f"Bad request: {response.text}")
                response.raise_for_status()

            # Update statistics on successful request
            num_texts = len(json_data.get('texts', [])) if isinstance(json_data, dict) and 'texts' in json_data else 1
            self._update_statistics(method_name, num_texts=num_texts)

            return response.json()

        except requests.exceptions.RequestException as e:
            self._update_statistics(method_name, error=True)
            if isinstance(e, requests.exceptions.HTTPError) and e.response is not None:
                if e.response.status_code in (401, 403):
                    raise ValidationError(f"Authentication failed: {str(e)}")
            raise WIBAError(f"Request failed: {str(e)}")

    def _validate_dataframe(self, df: pd.DataFrame, required_columns: List[str]) -> None:
        """Validate that the DataFrame has the required columns and non-empty data."""
        if not isinstance(df, pd.DataFrame):
            raise ValidationError("Input must be a pandas DataFrame")

        missing_columns = [col for col in required_columns if col not in df.columns]
        if missing_columns:
            raise ValidationError(f"DataFrame missing required columns: {missing_columns}")

        if df.empty:
            raise ValidationError("DataFrame is empty")

        # Check for null values in required columns
        null_counts = df[required_columns].isnull().sum()
        if null_counts.any():
            null_cols = null_counts[null_counts > 0].index.tolist()
            raise ValidationError(f"Null values found in columns: {null_cols}")

    def detect(self, texts: Union[str, List[str], pd.DataFrame], text_column: str = 'text', batch_size: int = 100, show_progress: bool = True) -> Union[ArgumentDetectionResult, List[ArgumentDetectionResult], pd.DataFrame]:
        """
        Detect arguments in text(s).

        Args:
            texts: Input text(s) - can be a single string, list of strings, or DataFrame
            text_column: Column name containing text if input is DataFrame
            batch_size: Number of texts to process in each batch for list/DataFrame inputs
            show_progress: Whether to show a progress bar for batch processing

        Returns:
            Single result, list of results, or DataFrame depending on input type
        """
        try:
            # Handle DataFrame input
            if isinstance(texts, pd.DataFrame):
                self._validate_dataframe(texts, [text_column])
                texts_list = texts.copy()
                texts_to_process = texts_list[text_column].tolist()

                results = []
                with tqdm(total=len(texts_to_process), desc="Detecting arguments", disable=not show_progress) as pbar:
                    for i in range(0, len(texts_to_process), batch_size):
                        batch = texts_to_process[i:i + batch_size]
                        response = self._make_request("POST", "/api/detect", {"texts": batch})

                        for text, result in zip(batch, response):
                            detection_result = ArgumentDetectionResult(
                                text=text,
                                argument_prediction=result['argument_prediction'],
                                confidence_score=result['argument_confidence'],
                                argument_components=None
                            )
                            results.append(detection_result)
                        pbar.update(len(batch))

                # Add results to DataFrame
                texts_list['argument_prediction'] = [r.argument_prediction for r in results]
                texts_list['argument_confidence'] = [r.confidence_score for r in results]
                return texts_list

            # Handle list input
            elif isinstance(texts, list):
                results = []
                with tqdm(total=len(texts), desc="Detecting arguments", disable=not show_progress) as pbar:
                    for i in range(0, len(texts), batch_size):
                        batch = texts[i:i + batch_size]
                        response = self._make_request("POST", "/api/detect", {"texts": batch})

                        for text, result in zip(batch, response):
                            detection_result = ArgumentDetectionResult(
                                text=text,
                                argument_prediction=result['argument_prediction'],
                                confidence_score=result['argument_confidence'],
                                argument_components=None
                            )
                            results.append(detection_result)
                        pbar.update(len(batch))
                return results

            # Handle single string input
            elif isinstance(texts, str):
                response = self._make_request("POST", "/api/detect", {"texts": [texts]})
                return ArgumentDetectionResult(
                    text=texts,
                    argument_prediction=response[0]['argument_prediction'],
                    confidence_score=response[0]['argument_confidence'],
                    argument_components=None
                )

            else:
                raise ValidationError("Input must be a string, list of strings, or DataFrame")

        except Exception:
            raise

    def extract(self, texts: Union[str, List[str], pd.DataFrame], text_column: str = 'text', batch_size: int = 100, show_progress: bool = True) -> Union[TopicExtractionResult, List[TopicExtractionResult], pd.DataFrame]:
        """
        Extract topics from text(s).

        Args:
            texts: Input text(s) - can be a single string, list of strings, or DataFrame
            text_column: Column name containing text if input is DataFrame
            batch_size: Number of texts to process in each batch for list/DataFrame inputs
            show_progress: Whether to show a progress bar for batch processing

        Returns:
            Single result, list of results, or DataFrame depending on input type
        """
        try:
            # Handle DataFrame input
            if isinstance(texts, pd.DataFrame):
                self._validate_dataframe(texts, [text_column])
                texts_list = texts.copy()
                texts_to_process = texts_list[text_column].tolist()

                results = []
                with tqdm(total=len(texts_to_process), desc="Extracting topics", disable=not show_progress) as pbar:
                    for i in range(0, len(texts_to_process), batch_size):
                        batch = texts_to_process[i:i + batch_size]
                        response = self._make_request("POST", "/api/extract", {"texts": batch})

                        for text, result in zip(batch, response):
                            extraction_result = TopicExtractionResult(
                                text=text,
                                topics=[result['extracted_topic']] if result['extracted_topic'] != 'No Topic' else [],
                                topic_metadata=None
                            )
                            results.append(extraction_result)
                        pbar.update(len(batch))

                # Add results to DataFrame
                texts_list['extracted_topics'] = [','.join(r.topics) if r.topics else 'No Topic' for r in results]
                return texts_list

            # Handle list input
            elif isinstance(texts, list):
                results = []
                with tqdm(total=len(texts), desc="Extracting topics", disable=not show_progress) as pbar:
                    for i in range(0, len(texts), batch_size):
                        batch = texts[i:i + batch_size]
                        response = self._make_request("POST", "/api/extract", {"texts": batch})

                        for text, result in zip(batch, response):
                            extraction_result = TopicExtractionResult(
                                text=text,
                                topics=[result['extracted_topic']] if result['extracted_topic'] != 'No Topic' else [],
                                topic_metadata=None
                            )
                            results.append(extraction_result)
                        pbar.update(len(batch))
                return results

            # Handle single string input
            elif isinstance(texts, str):
                response = self._make_request("POST", "/api/extract", {"text": texts})
                return TopicExtractionResult(
                    text=texts,
                    topics=[response[0]['extracted_topic']] if response[0]['extracted_topic'] != 'No Topic' else [],
                    topic_metadata=None
                )

            else:
                raise ValidationError("Input must be a string, list of strings, or DataFrame")

        except Exception:
            raise

    def stance(self, texts: Union[str, List[str], pd.DataFrame], topics: Union[str, List[str], None] = None,
               text_column: str = 'text', topic_column: str = 'topic', batch_size: int = 100,
               show_progress: bool = True) -> Union[StanceAnalysisResult, List[StanceAnalysisResult], pd.DataFrame]:
        """
        Analyze the stance of text(s) in relation to topic(s).

        Args:
            texts: Input text(s) - can be a single string, list of strings, or DataFrame
            topics: Topic(s) - required unless input is a DataFrame with topic_column
            text_column: Column name containing text if input is DataFrame
            topic_column: Column name containing topics if input is DataFrame
            batch_size: Number of texts to process in each batch for list/DataFrame inputs
            show_progress: Whether to show a progress bar for batch processing

        Returns:
            Single result, list of results, or DataFrame depending on input type
        """
        try:
            # Handle DataFrame input
            if isinstance(texts, pd.DataFrame):
                self._validate_dataframe(texts, [text_column, topic_column])
                texts_list = texts.copy()
                texts_to_process = texts_list[text_column].tolist()
                topics_to_process = texts_list[topic_column].tolist()

                results = []
                with tqdm(total=len(texts_to_process), desc="Analyzing stances", disable=not show_progress) as pbar:
                    for i in range(0, len(texts_to_process), batch_size):
                        batch_texts = texts_to_process[i:i + batch_size]
                        batch_topics = topics_to_process[i:i + batch_size]

                        response = self._make_request(
                            "POST",
                            "/api/stance",
                            {
                                "texts": batch_texts,
                                "topics": batch_topics
                            }
                        )

                        for text, topic, result in zip(batch_texts, batch_topics, response):
                            stance_text = self.STANCE_MAP.get(result['stance_prediction'], result['stance_prediction'])
                            stance_result = StanceAnalysisResult(
                                text=text,
                                topic=topic,
                                stance=stance_text,
                                supporting_evidence=None
                            )
                            results.append(stance_result)
                        pbar.update(len(batch_texts))

                # Add results to DataFrame
                texts_list['stance'] = [r.stance for r in results]
                return texts_list

            # Handle list input
            elif isinstance(texts, list):
                if not topics or not isinstance(topics, list) or len(texts) != len(topics):
                    raise ValidationError("Must provide a matching list of topics for a list of texts")

                results = []
                with tqdm(total=len(texts), desc="Analyzing stances", disable=not show_progress) as pbar:
                    for i in range(0, len(texts), batch_size):
                        batch_texts = texts[i:i + batch_size]
                        batch_topics = topics[i:i + batch_size]

                        response = self._make_request(
                            "POST",
                            "/api/stance",
                            {
                                "texts": batch_texts,
                                "topics": batch_topics
                            }
                        )

                        for text, topic, result in zip(batch_texts, batch_topics, response):
                            stance_text = self.STANCE_MAP.get(result['stance_prediction'], result['stance_prediction'])
                            stance_result = StanceAnalysisResult(
                                text=text,
                                topic=topic,
                                stance=stance_text,
                                supporting_evidence=None
                            )
                            results.append(stance_result)
                        pbar.update(len(batch_texts))
                return results

            # Handle single string input
            elif isinstance(texts, str):
                if not topics or not isinstance(topics, str):
                    raise ValidationError("Must provide a topic string for single text input")

                response = self._make_request(
                    "POST",
                    "/api/stance",
                    {
                        "texts": [texts],
                        "topics": [topics]
                    }
                )

                stance_text = self.STANCE_MAP.get(response[0]['stance_prediction'], response[0]['stance_prediction'])
                return StanceAnalysisResult(
                    text=texts,
                    topic=topics,
                    stance=stance_text,
                    supporting_evidence=None
                )

            else:
                raise ValidationError("Input must be a string, list of strings, or DataFrame")

        except Exception:
            raise

    def analyze_stance(self, texts: Union[str, List[str], pd.DataFrame], topics: Union[str, List[str], None] = None,
                       text_column: str = 'text', topic_column: str = 'topic', batch_size: int = 100,
                       show_progress: bool = True) -> Union[StanceAnalysisResult, List[StanceAnalysisResult], pd.DataFrame]:
        """Deprecated: Use stance() instead"""
        return self.stance(texts, topics, text_column=text_column, topic_column=topic_column,
                           batch_size=batch_size, show_progress=show_progress)

    def process_csv(self, csv_data: Union[str, StringIO], text_column: str = 'text', topic_column: Optional[str] = None,
                    detect: bool = True, extract: bool = True, stance: bool = False, batch_size: int = 100) -> pd.DataFrame:
        """Process a CSV file through multiple analyses.

        Args:
            csv_data: CSV string or StringIO object
            text_column: Name of the column containing text to analyze
            topic_column: Name of the column containing topics (required for stance analysis)
            detect: Whether to perform argument detection
            extract: Whether to perform topic extraction
            stance: Whether to perform stance analysis
            batch_size: Number of texts to process in each batch

        Returns:
            DataFrame with results from all requested analyses
        """
        try:
            # Read CSV
            if isinstance(csv_data, str):
                df = pd.read_csv(StringIO(csv_data))
            else:
                df = pd.read_csv(csv_data)

            self._validate_dataframe(df, [text_column])

            # Perform requested analyses
            if detect:
                df = self.process_dataframe_detect(df, text_column, batch_size)

            if extract:
                df = self.process_dataframe_extract(df, text_column, batch_size)

            if stance:
                if not topic_column or topic_column not in df.columns:
                    raise ValidationError("Topic column required for stance analysis")
                df = self.process_dataframe_stance(df, text_column, topic_column, batch_size)

            return df

        except Exception as e:
            self.logger.error("CSV processing failed", error=str(e))
            raise

    def save_results(self, df: pd.DataFrame, output_path: str, format: str = 'csv') -> None:
        """Save a results DataFrame to file.

        Args:
            df: DataFrame to save
            output_path: Path to save the file
            format: Output format ('csv' or 'json')
        """
        try:
            if format.lower() == 'csv':
                df.to_csv(output_path, index=False)
            elif format.lower() == 'json':
                df.to_json(output_path, orient='records', lines=True)
            else:
                raise ValueError(f"Unsupported output format: {format}")

        except Exception as e:
            self.logger.error("Failed to save results", error=str(e))
            raise

    def process_dataframe_detect(self, df: pd.DataFrame, text_column: str = 'text', batch_size: int = 100) -> pd.DataFrame:
        """Deprecated: Use detect() instead"""
        return self.detect(df, text_column=text_column, batch_size=batch_size)

    def process_dataframe_extract(self, df: pd.DataFrame, text_column: str = 'text', batch_size: int = 100) -> pd.DataFrame:
        """Deprecated: Use extract() instead"""
        return self.extract(df, text_column=text_column, batch_size=batch_size)

    def process_dataframe_stance(self, df: pd.DataFrame, text_column: str = 'text', topic_column: str = 'topic', batch_size: int = 100) -> pd.DataFrame:
        """Deprecated: Use stance() instead"""
        return self.stance(df, text_column=text_column, topic_column=topic_column, batch_size=batch_size)

    def discover_arguments(self, texts: Union[str, pd.DataFrame], text_column: str = 'text', window_size: int = 3,
                           step_size: int = 1, batch_size: int = 5, show_progress: bool = True,
                           max_text_length: int = 10000) -> pd.DataFrame:
        """
        Discover arguments in text(s) by segmenting and analyzing them.

        Args:
            texts: Input text(s) - can be a single string or DataFrame
            text_column: Column name containing text if input is DataFrame
            window_size: Size of the sliding window. Defaults to 3.
            step_size: Step size for the sliding window. Defaults to 1.
            batch_size: Number of texts to process in each batch for DataFrame input. Defaults to 5.
            show_progress: Whether to show a progress bar for DataFrame input. Defaults to True.
            max_text_length: Maximum allowed length for each text. Defaults to 10000.

        Returns:
            pd.DataFrame: DataFrame containing the discovered arguments and segments
        """
        try:
            # Handle DataFrame input
            if isinstance(texts, pd.DataFrame):
                self._validate_dataframe(texts, [text_column])
                texts_list = texts.copy()
                texts_to_process = texts_list[text_column].tolist()

                # Validate text lengths
                for text in texts_to_process:
                    if len(text) > max_text_length:
                        raise ValidationError(f"Text exceeds maximum length of {max_text_length} characters")

                all_results = []
                with tqdm(total=len(texts_list), desc="Discovering arguments", disable=not show_progress) as pbar:
                    # Process in batches
                    for i in range(0, len(texts_list), batch_size):
                        batch_texts = texts_list.iloc[i:i + batch_size]
                        batch_results = []

                        # Process each text in the batch
                        for _, row in batch_texts.iterrows():
                            text = row[text_column]
                            result_df = self._discover_arguments_single(text, window_size, step_size)

                            # Add original row data to results
                            for col in batch_texts.columns:
                                if col != text_column:
                                    result_df[col] = row[col]

                            batch_results.append(result_df)

                        # Add batch results and update progress
                        all_results.extend(batch_results)
                        pbar.update(len(batch_texts))

                # Combine all results
                return pd.concat(all_results, ignore_index=True)

            # Handle single string input
            elif isinstance(texts, str):
                if len(texts) > max_text_length:
                    raise ValidationError(f"Text exceeds maximum length of {max_text_length} characters")
                return self._discover_arguments_single(texts, window_size, step_size)

            else:
                raise ValidationError("Input must be a string or DataFrame")

        except Exception:
            raise

    def _discover_arguments_single(self, text: str, window_size: int, step_size: int) -> pd.DataFrame:
        """Internal method to discover arguments in a single text."""
        # Validate input
        if not text or not isinstance(text, str):
            raise ValidationError("Input text must be a non-empty string")
        if window_size < 1:
            raise ValidationError("window_size must be greater than 0")
        if step_size < 1:
            raise ValidationError("step_size must be greater than 0")

        # Prepare request data in the format expected by the server
        request_data = {
            "text": text,
            "params": {
                "window_size": window_size,
                "step_size": step_size,
                "min_segment_length": 1,
                "max_segment_length": 100,
                "overlap": True
            }
        }

        try:
            # Make request to discover arguments
            response = self._make_request(
                method="POST",
                endpoint="/api/discover_arguments",
                data=request_data,
                headers={
                    "Content-Type": "application/json",
                    "X-Request-ID": str(uuid.uuid4())
                }
            )

            # Convert response to DataFrame
            if isinstance(response, list):
                result_df = pd.DataFrame(response)
            else:
                result_df = pd.DataFrame([response])

            # Ensure all required columns are present
            required_columns = [
                'id', 'text', 'text_segment', 'start_index', 'end_index',
                'argument_prediction', 'argument_confidence',
                'overlapping_segments', 'processed_text'
            ]

            for col in required_columns:
                if col not in result_df.columns:
                    if col == 'overlapping_segments':
                        # The column is absent, so default every row to an empty list
                        result_df[col] = [[] for _ in range(len(result_df))]
                    elif col == 'argument_prediction':
                        # If segment_type exists, use it to set argument_prediction
                        if 'segment_type' in result_df.columns:
                            result_df[col] = result_df['segment_type'].apply(
                                lambda x: self.ARGUMENT_VALUES.get(str(x).lower(), 'NoArgument')
                            )
                        else:
                            result_df[col] = 'NoArgument'
                    else:
                        result_df[col] = None

            # Standardize argument prediction values using the class constant
            result_df['argument_prediction'] = result_df['argument_prediction'].apply(
                lambda x: self.ARGUMENT_VALUES.get(str(x).lower(), 'NoArgument')
            )

            # Remove redundant columns if they exist
            columns_to_drop = ['segment_type', 'is_argument']
            result_df = result_df.drop(columns=[col for col in columns_to_drop if col in result_df.columns])

            # Sort by start_index and argument_confidence
            result_df = result_df.sort_values(
                ['start_index', 'argument_confidence'],
                ascending=[True, False]
            )

            return result_df

        except Exception as e:
            raise WIBAError(f"Failed to discover arguments: {str(e)}")
wiba-0.1.0.dist-info/LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 Arman Akbarian

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
wiba-0.1.0.dist-info/METADATA
ADDED
@@ -0,0 +1,157 @@
Metadata-Version: 2.2
Name: wiba
Version: 0.1.0
Summary: WIBA: What Is Being Argued? A Comprehensive Approach to Argument Mining
Home-page: https://github.com/Armaniii/WIBA
Author: Arman Irani
Author-email: airan002@ucr.edu
Project-URL: Bug Tracker, https://github.com/Armaniii/WIBA/issues
Project-URL: Documentation, https://wiba.dev
Project-URL: Source Code, https://github.com/Armaniii/WIBA
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Topic :: Text Processing :: Linguistic
Requires-Python: >=3.8
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: requests>=2.25.0
Requires-Dist: pandas>=1.2.0
Requires-Dist: numpy>=1.19.0
Requires-Dist: tqdm>=4.50.0
Requires-Dist: structlog>=21.1.0
Dynamic: author
Dynamic: author-email
Dynamic: classifier
Dynamic: description
Dynamic: description-content-type
Dynamic: home-page
Dynamic: project-url
Dynamic: requires-dist
Dynamic: requires-python
Dynamic: summary

# WIBA: What Is Being Argued?

WIBA is a comprehensive argument mining toolkit that helps you detect, analyze, and understand arguments in text. It provides a simple yet powerful interface to identify argumentative content, extract topics, analyze stance, and discover arguments in longer texts.

## Installation

```bash
pip install wiba
```

## Quick Start

```python
from wiba import WIBA

# Initialize with your API token
analyzer = WIBA(api_token="your_api_token_here")

# Example text
text = "Climate change is real because global temperatures are rising."

# Detect if it's an argument
result = analyzer.detect(text)
print(f"Argument detected: {result.argument_prediction}")
print(f"Confidence: {result.confidence_score}")
```
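
The client keeps a pooled, retrying HTTP session open for its lifetime, so it can also be used as a context manager to guarantee that the session is closed. A minimal sketch (the token value is a placeholder):

```python
from wiba import WIBA

# __enter__/__exit__ close the underlying requests.Session automatically.
with WIBA(api_token="your_api_token_here") as analyzer:
    result = analyzer.detect("Climate change is real because global temperatures are rising.")
    print(result.argument_prediction)

# Outside a `with` block, call analyzer.close() yourself when finished.
```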

## Features

- **Argument Detection**: Identify whether a text contains an argument
- **Topic Extraction**: Extract the main topic being argued about
- **Stance Analysis**: Determine the stance towards a specific topic
- **Argument Discovery**: Find argumentative segments in longer texts
- **Batch Processing**: Efficiently process multiple texts (see the DataFrame sketch after this list)
- **DataFrame Support**: Native pandas DataFrame integration
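
For batch work, `detect` (and likewise `extract` and `stance`) accepts a pandas DataFrame and returns a copy with prediction columns added. A minimal sketch, assuming your DataFrame has a `text` column and the token is a placeholder:

```python
import pandas as pd
from wiba import WIBA

analyzer = WIBA(api_token="your_api_token_here")

df = pd.DataFrame({"text": [
    "Climate change is real because temperatures are rising.",
    "The meeting starts at noon.",
]})

# Rows are sent to the API in batches of `batch_size`; the returned copy
# gains 'argument_prediction' and 'argument_confidence' columns.
results = analyzer.detect(df, text_column="text", batch_size=100, show_progress=False)
print(results[["text", "argument_prediction", "argument_confidence"]])
```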

## Documentation

For detailed documentation and examples, visit [wiba.dev](https://wiba.dev).

## Getting Started

1. Create an account at [wiba.dev](https://wiba.dev) to get your API token
2. Install the package: `pip install wiba`
3. Initialize the client with your token
4. Start analyzing arguments!

## Example Usage

### Detect Arguments

```python
# Single text
result = analyzer.detect("Climate change is real because temperatures are rising.")
print(result.argument_prediction)  # "Argument" or "NoArgument"
print(result.confidence_score)  # Confidence score between 0 and 1

# Multiple texts
texts = [
    "Climate change is real because temperatures are rising.",
    "This is just a simple statement without any argument."
]
results = analyzer.detect(texts)
for r in results:
    print(f"Text: {r.text}")
    print(f"Prediction: {r.argument_prediction}")
```

### Extract Topics

```python
result = analyzer.extract("Climate change is a serious issue because it affects our environment.")
print(result.topics)  # List of extracted topics
```
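
`extract` also accepts a list of texts and returns one result per text; texts the API labels "No Topic" come back with an empty `topics` list. A short sketch:

```python
results = analyzer.extract([
    "Climate change is a serious issue because it affects our environment.",
    "It rained yesterday.",
])
for r in results:
    # An empty list means the API returned "No Topic" for this text.
    print(r.text, "->", r.topics or "no topic found")
```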

### Analyze Stance

```python
text = "We must take action on climate change because the evidence is overwhelming."
topic = "climate change"
result = analyzer.stance(text, topic)
print(f"Stance: {result.stance}")  # "Favor", "Against", or "NoArgument"
```
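
Stance analysis also works over paired lists (the i-th topic applies to the i-th text, so the lists must be the same length) and over DataFrames with text and topic columns. A minimal sketch:

```python
import pandas as pd

texts = [
    "We must take action on climate change because the evidence is overwhelming.",
    "Carbon taxes should be repealed because they hurt the economy.",
]
topics = ["climate change", "carbon taxes"]

# Paired lists return one StanceAnalysisResult per (text, topic) pair.
for r in analyzer.stance(texts, topics):
    print(f"{r.topic}: {r.stance}")

# DataFrame input reads `text_column` and `topic_column` and adds a 'stance' column.
df = pd.DataFrame({"text": texts, "topic": topics})
df = analyzer.stance(df, text_column="text", topic_column="topic")
```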

### Discover Arguments

```python
text = """Climate change is a serious issue. Global temperatures are rising at an
unprecedented rate. This is causing extreme weather events. However, some argue
that natural climate cycles are responsible."""

results_df = analyzer.discover_arguments(
    text,
    window_size=2,  # Number of sentences per window
    step_size=1     # Number of sentences to move the window
)
print(results_df[['text_segment', 'argument_prediction', 'argument_confidence']])
```
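
The client also bundles a small pipeline for tabular workflows: `process_csv` runs the selected analyses over a CSV in one pass, `save_results` writes a results DataFrame to CSV or JSON, and `get_statistics` reports the client-side usage counters. A sketch:

```python
from io import StringIO

csv_data = StringIO(
    "text\n"
    "Climate change is real because temperatures are rising.\n"
    "Lunch was good.\n"
)

# Run argument detection and topic extraction over every row.
df = analyzer.process_csv(csv_data, text_column="text", detect=True, extract=True, stance=False)
analyzer.save_results(df, "results.csv", format="csv")

# Usage counters are tracked locally, per method.
stats = analyzer.get_statistics()
print(stats["overview"]["total_api_calls"], stats["method_breakdown"])
```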

## Citation

If you use WIBA in your research, please cite:

```bibtex
@misc{irani2024wibaarguedcomprehensiveapproach,
      title={WIBA: What Is Being Argued? A Comprehensive Approach to Argument Mining},
      author={Arman Irani and Ju Yeon Park and Kevin Esterling and Michalis Faloutsos},
      year={2024},
      eprint={2405.00828},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2405.00828},
}
```

## License

This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
wiba-0.1.0.dist-info/RECORD
ADDED
@@ -0,0 +1,6 @@
wiba/__init__.py,sha256=LY2nqW2D4BwNUim2RIx7Wa3hnInKKugsbTJb0E85KZ0,39551
wiba-0.1.0.dist-info/LICENSE,sha256=fPUf3LPdrEm9c7af7DvOfkWLDbWtR-D1QB34Q04JeUE,1071
wiba-0.1.0.dist-info/METADATA,sha256=QcV9kA71kc1kwNeui4uZ9Ix-CuWYWizYCplbuuRdChU,4878
wiba-0.1.0.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
wiba-0.1.0.dist-info/top_level.txt,sha256=nAllQVKrFATsq88y_OeOwPEtwucwbqkZ0jHB8OkKclI,5
wiba-0.1.0.dist-info/RECORD,,
wiba-0.1.0.dist-info/top_level.txt
ADDED
@@ -0,0 +1 @@
wiba