dingo-python 1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,551 @@
1
+ import re
2
+ import jieba
3
+ import langid
4
+ import textstat
5
+
6
+ from typing import List, Tuple
7
+ from hanziconv import HanziConv
8
+ from nltk.tokenize import WordPunctTokenizer
9
+
10
+
11
+ from dingo.model.model import Model
12
+ from dingo.model.rule.util import normalize, base_rps_frac_chars_in_dupe_ngrams, get_stop_words, split_paragraphs, TextSlice
13
+ from dingo.model.rule.base import BaseRule, ResModel
14
+
15
+
16
@Model.rule_register('QUALITY_SIGNAL_COMPLETENESS', ['default','sft','pretrain','benchmark'])
class CommonColonEnd(BaseRule):
    """Flag content whose final character is an ASCII colon (':').

    Empty content is never flagged.
    """

    @classmethod
    def eval(cls, input_data: List[str]) -> ResModel:
        assert len(input_data) == 1
        result = ResModel()
        text = input_data[0]
        # str.endswith is False for "", which covers the empty-content case.
        if text.endswith(':'):
            result.error_status = True
            result.error_reason = 'Ends with a colon.'
        return result
30
+
31
+
32
@Model.rule_register('QUALITY_SIGNAL_EFFECTIVENESS', ['default','sft','pretrain','benchmark'])
class CommonContentNull(BaseRule):
    """Flag content that is empty or made up solely of whitespace."""

    @classmethod
    def eval(cls, input_data: List[str]) -> ResModel:
        assert len(input_data) == 1
        result = ResModel()
        if not input_data[0].strip():
            result.error_status = True
            result.error_reason = 'Content is empty.'
        return result
45
+
46
+
47
@Model.rule_register('QUALITY_SIGNAL_SIMILARITY', ['default','sft','pretrain','benchmark'])
class CommonDocRepeat(BaseRule):
    """Flag content whose duplicated-6-gram score is 80 or higher.

    The score comes from ``base_rps_frac_chars_in_dupe_ngrams(6, content)``.
    """

    @classmethod
    def eval(cls, input_data: List[str]) -> ResModel:
        assert len(input_data) == 1
        result = ResModel()
        # NOTE(review): a threshold of 80 suggests the helper returns a
        # percentage rather than a 0-1 fraction -- confirm in dingo.model.rule.util.
        score = base_rps_frac_chars_in_dupe_ngrams(6, input_data[0])
        if score >= 80:
            result.error_status = True
            result.error_reason = 'Repeatability of text is too high, with ratio: ' + str(score)
        return result
60
+
61
+
62
@Model.rule_register('QUALITY_SIGNAL_RELEVANCE', ['default','sft','pretrain','benchmark'])
class CommonHtmlEntity(BaseRule):
    """Flag content containing HTML entity references, terminated or not.

    For every name in ``entities``, both ``&name;`` and the unterminated
    ``&name`` are matched. On a hit, ``error_reason`` is the list of matched
    substrings (as returned by ``re.findall``).
    """

    entities = [
        "nbsp",
        "lt",
        "gt",
        "amp",
        "quot",
        "apos",
        "hellip",
        "ndash",
        "mdash",
        "lsquo",
        "rsquo",
        "ldquo",
        "rdquo",
    ]
    # Compiled once per class instead of being rebuilt on every call. The
    # greedy ";?" reports "&lt;" whole and still catches a bare "&lt". No name
    # is a prefix of another, so the order of alternatives does not matter.
    # NOTE(review): an earlier version concatenated four identical "&name;"
    # lists; distinct variants (e.g. full-width characters) may have been
    # intended -- confirm.
    pattern = re.compile("&(?:" + "|".join(entities) + ");?")

    @classmethod
    def eval(cls, input_data: List[str]) -> ResModel:
        assert len(input_data) == 1
        res = ResModel()
        matches = cls.pattern.findall(input_data[0])
        if matches:
            res.error_status = True
            res.error_reason = matches
        return res
106
+
107
+
108
@Model.rule_register('QUALITY_SIGNAL_SECURITY', ['default','sft','pretrain','benchmark'])
class CommonLicenseKey(BaseRule):
    """Flag content containing something shaped like a software license key.

    Shapes searched for (case-insensitively, via ``re.I``):

    * 47 consecutive ASCII letters/digits (this also matches any longer
      alphanumeric run, such as hex digests);
    * four dash-separated groups of four alphanumerics: ``XXXX-XXXX-XXXX-XXXX``;
    * ``XXXX-<8 digits>-XXXX``.
    """
    # Raw strings: "\d" in a plain string literal is an invalid escape
    # sequence (DeprecationWarning, and SyntaxWarning since Python 3.12).
    pattern = "|".join(
        [
            r"[A-Z0-9]{47}",
            r"[A-Z0-9]{4}-[A-Z0-9]{4}-[A-Z0-9]{4}-[A-Z0-9]{4}",
            r"[A-Z0-9]{4}-\d{8}-[A-Z0-9]{4}",
        ]
    )

    @classmethod
    def eval(cls, input_data: List[str]) -> ResModel:
        assert len(input_data) == 1
        res = ResModel()
        match = re.search(cls.pattern, input_data[0], re.I)
        if match:
            res.error_status = True
            res.error_reason = "Contain license key."
        return res
128
+
129
+
130
@Model.rule_register('QUALITY_SIGNAL_FLUENCY', ['default','sft','pretrain','benchmark'])
class CommonNoPunc(BaseRule):
    """Flag content with an over-long unpunctuated run that is also hard to read.

    The content is split into lines, each line into segments at the
    punctuation characters in ``_SPLIT_PATTERN``, and each segment into
    whitespace-separated words. The content is flagged when the longest
    segment has more than 56 words AND ``textstat.flesch_reading_ease`` of
    the whole content is below 20.
    """
    # Segment separators: hyphen, en dash, ASCII punctuation, bullets and CJK
    # punctuation. Compiled once instead of on every line.
    _SPLIT_PATTERN = re.compile(r'[-–.!?,;•、。!?,;·]')

    @classmethod
    def eval(cls, input_data: List[str]) -> ResModel:
        assert len(input_data) == 1
        res = ResModel()
        content = input_data[0]
        max_word_count = 0
        for paragraph in content.split('\n'):
            if not paragraph:
                continue
            for sentence in cls._SPLIT_PATTERN.split(paragraph):
                max_word_count = max(max_word_count, len(sentence.split()))
        # ``and`` short-circuits, so the readability score is only computed
        # when the word-count condition already holds.
        if max_word_count > 56 and textstat.flesch_reading_ease(content) < 20:
            res.error_status = True
            res.error_reason = 'Paragraph without punctuation.'
        return res
154
+
155
+
156
@Model.rule_register('QUALITY_SIGNAL_RELEVANCE', ['default','sft','pretrain','benchmark'])
class CommonSpecialCharacter(BaseRule):
    """Flag content containing garbage/placeholder characters.

    Matches what appears to be the U+FFFD replacement character, the white
    square '□', and the literal token ``{/U}``. On a hit, ``error_reason``
    is the list of every match.
    """
    pattern = r"[�□]|\{\/U\}"

    @classmethod
    def eval(cls, input_data: List[str]) -> ResModel:
        assert len(input_data) == 1
        result = ResModel()
        found = re.findall(cls.pattern, input_data[0])
        if not found:
            return result
        result.error_status = True
        result.error_reason = found
        return result
169
+
170
+
171
@Model.rule_register("QUALITY_SIGNAL_RELEVANCE", ['zh_all'])
class CommonWatermark(BaseRule):
    """Flag content containing any watermark pattern listed in ``key_list``.

    Non-empty ``key_list`` entries are joined with '|' and used as a regular
    expression, so entries may contain regex syntax. On a hit,
    ``error_reason`` is the list returned by ``re.findall``. With no
    (non-empty) entries configured, nothing is flagged.
    """
    key_list = []

    @classmethod
    def eval(cls, input_data: List[str]) -> ResModel:
        res = ResModel()
        assert len(input_data) == 1
        # Empty entries must be dropped: '|'.join([]) == '' (and an empty
        # alternative) matches the empty string at every position, so
        # re.findall would return a non-empty list and flag every input.
        keys = [key for key in cls.key_list if key]
        if not keys:
            return res
        matches = re.findall('|'.join(keys), input_data[0])
        if matches:
            res.error_status = True
            res.error_reason = matches
        return res
185
+
186
+
187
@Model.rule_register("QUALITY_SIGNAL_COMPLETENESS", ['en_all','pretrain'])
class CommonWordNumber(BaseRule):
    """Flag content whose normalized word count is outside [50, 100000).

    Words are the whitespace-separated tokens of ``normalize(content)``. The
    lower bound is inclusive and the upper bound exclusive.
    """

    @classmethod
    def eval(cls, input_data: List[str]) -> ResModel:
        assert len(input_data) == 1
        result = ResModel()
        word_total = len(normalize(input_data[0]).split())
        if not 50 <= word_total < 100000:
            result.error_status = True
            result.error_reason = "The number of word is: " + str(word_total)
        return result
204
+
205
+
206
@Model.rule_register('QUALITY_SIGNAL_EFFECTIVENESS', ['en_all','pretrain'])
class CommonMeanWordLength(BaseRule):
    """Flag content whose mean normalized word length is outside [3, 10).

    Words are the whitespace-separated tokens of ``normalize(content)``. The
    mean is rounded to two decimals before it is compared. Content with no
    words is not flagged.
    """

    @classmethod
    def eval(cls, input_data: List[str]) -> ResModel:
        assert len(input_data) == 1
        result = ResModel()
        words = normalize(input_data[0]).split()
        if not words:
            return result
        total_chars = float(sum(len(word) for word in words))
        average = round(total_chars / len(words), 2)
        if 3 <= average < 10:
            return result
        result.error_status = True
        result.error_reason = "The mean length of word is: " + str(average)
        return result
229
+
230
+
231
@Model.rule_register('QUALITY_SIGNAL_EFFECTIVENESS', ['en_all','sft','pretrain','benchmark'])
class CommonSymbolWordRatio(BaseRule):
    """Flag content whose symbol-to-word ratio exceeds 0.4.

    Symbols are the non-overlapping occurrences, in the raw text, of each
    string in ``key_list``. Words are ``WordPunctTokenizer`` tokens. Content
    with no tokens is not flagged.
    """
    # NOTE(review): this rule's docs previously named 0.1 as the threshold,
    # while the check uses 0.4 -- confirm which is intended.
    key_list = ["#", "...", "…"]

    @classmethod
    def eval(cls, input_data: List[str]) -> ResModel:
        assert len(input_data) == 1
        result = ResModel()
        text = input_data[0]
        token_total = len(WordPunctTokenizer().tokenize(text))
        if token_total == 0:
            return result
        symbol_total = 0.0
        for symbol in cls.key_list:
            symbol_total += text.count(symbol)
        ratio = symbol_total / token_total
        if ratio > 0.4:
            result.error_status = True
            result.error_reason = "The ratio of symbol / word is: " + str(ratio)
        return result
256
+
257
+
258
@Model.rule_register("QUALITY_SIGNAL_EFFECTIVENESS", ['en_all','pretrain'])
class CommonAlphWords(BaseRule):
    """Flag content in which at most 80% of tokens contain an ASCII letter.

    Tokens come from ``WordPunctTokenizer``. Content with no tokens is not
    flagged.
    """

    @classmethod
    def eval(cls, input_data: List[str]) -> ResModel:
        assert len(input_data) == 1
        result = ResModel()
        tokens = WordPunctTokenizer().tokenize(input_data[0])
        if len(tokens) == 0:
            return result
        has_letter = re.compile(r"[a-zA-Z]")
        alpha_total = float(len([tok for tok in tokens if has_letter.search(tok) is not None]))
        ratio = alpha_total / len(tokens)
        if ratio <= 0.8:
            result.error_status = True
            result.error_reason = "The ratio of words that contain at least one alphabetic character is: " + str(ratio)
        return result
285
+
286
+
287
@Model.rule_register('QUALITY_SIGNAL_EFFECTIVENESS', ['pretrain'])
class CommonStopWord(BaseRule):
    """Flag content in which more than 40% of tokens are English stop words.

    Tokens come from ``WordPunctTokenizer`` and are tested as-is (no case
    folding here) for membership in ``get_stop_words("en")``. Content with no
    tokens is not flagged.
    """

    @classmethod
    def eval(cls, input_data: List[str]) -> ResModel:
        assert len(input_data) == 1
        result = ResModel()
        tokens = WordPunctTokenizer().tokenize(input_data[0])
        if len(tokens) == 0:
            return result
        # NOTE(review): this flags stop-word-*heavy* text; confirm that
        # direction is intended (earlier docs mentioned "> 2").
        stop_words = get_stop_words("en")
        stop_total = sum(1 for tok in tokens if tok in stop_words)
        ratio = stop_total / len(tokens)
        if ratio > 0.4:
            result.error_status = True
            result.error_reason = "The ratio of stop words is: " + str(ratio)
        return result
310
+
311
+
312
@Model.rule_register("QUALITY_SIGNAL_COMPLETENESS", ['pretrain'])
class CommonSentenceNumber(BaseRule):
    """Flag content that splits into three or fewer sentences.

    A sentence is a run of characters other than '.', '!' and '?' that
    starts at a word boundary, together with any terminators that follow it.
    """

    @classmethod
    def eval(cls, input_data: List[str]) -> ResModel:
        assert len(input_data) == 1
        result = ResModel()
        sentence_re = re.compile(r'\b[^.!?]+[.!?]*', flags=re.UNICODE)
        sentence_total = len(sentence_re.findall(input_data[0]))
        if sentence_total > 3:
            return result
        result.error_status = True
        result.error_reason = "The number of sentence is: " + str(sentence_total)
        return result
328
+
329
+
330
@Model.rule_register("QUALITY_SIGNAL_UNDERSTANDABILITY", [])
class CommonCurlyBracket(BaseRule):
    """Flag content containing '{' or '}'.

    On a hit, ``error_reason`` lists every brace character found.
    """
    pattern = "[{}]"

    @classmethod
    def eval(cls, input_data: List[str]) -> ResModel:
        assert len(input_data) == 1
        result = ResModel()
        braces = re.findall(cls.pattern, input_data[0])
        if not braces:
            return result
        result.error_status = True
        result.error_reason = braces
        return result
344
+
345
+
346
@Model.rule_register("QUALITY_SIGNAL_UNDERSTANDABILITY", ['pretrain'])
class CommonCapitalWords(BaseRule):
    """Flag content in which more than 10% of tokens are upper-case.

    Tokens come from ``WordPunctTokenizer``. A token counts as upper-case when
    ``str.isupper()`` is true, i.e. it has at least one cased character and
    all of its cased characters are upper-case. Content with no tokens is not
    flagged.
    """

    @classmethod
    def eval(cls, input_data: List[str]) -> ResModel:
        assert len(input_data) == 1
        res = ResModel()
        raw_words = WordPunctTokenizer().tokenize(input_data[0])
        num_words = len(raw_words)
        if num_words == 0:
            return res

        # No stdout output here: this runs once per evaluated document.
        num_capital_words = sum(word.isupper() for word in raw_words)
        ratio = num_capital_words / num_words
        if ratio > 0.1:
            res.error_status = True
            res.error_reason = "The ratio of capital words is: " + str(ratio)
        return res
368
+
369
+
370
@Model.rule_register("QUALITY_SIGNAL_EFFECTIVENESS", ['sft','pretrain','benchmark'])
class CommonLoremIpsum(BaseRule):
    """Flag content containing "lorem ipsum" placeholder text.

    The ratio is the number of case-insensitive "lorem ipsum" occurrences in
    ``normalize(content)`` divided by that string's length. Content is flagged
    when the ratio exceeds 3e-08. In practice any single occurrence is enough
    unless the normalized text is longer than about 33 million characters.
    Empty normalized content is not flagged.
    """

    @classmethod
    def eval(cls, input_data: List[str]) -> ResModel:
        assert len(input_data) == 1
        result = ResModel()
        text = normalize(input_data[0])
        if len(text) == 0:
            return result
        placeholder_re = re.compile(r"lorem ipsum", re.IGNORECASE)
        ratio = len(placeholder_re.findall(text)) / len(text)
        if ratio > 3e-08:
            result.error_status = True
            result.error_reason = "The ratio of lorem ipsum is: " + str(ratio)
        return result
390
+
391
+
392
@Model.rule_register("QUALITY_SIGNAL_UNDERSTANDABILITY", ['pretrain'])
class CommonUniqueWords(BaseRule):
    """Flag content whose distinct-word ratio is 0.1 or lower.

    Words are the whitespace-separated tokens of ``normalize(content)``.
    Content with no words is not flagged.
    """

    @classmethod
    def eval(cls, input_data: List[str]) -> ResModel:
        assert len(input_data) == 1
        result = ResModel()
        words = normalize(input_data[0]).split()
        if not words:
            return result
        ratio = len(set(words)) / len(words)
        if ratio <= 0.1:
            result.error_status = True
            result.error_reason = "The ratio of unique words is: " + str(ratio)
        return result
415
+
416
+
417
@Model.rule_register("QUALITY_SIGNAL_EFFECTIVENESS", ['pretrain'])
class CommonCharNumber(BaseRule):
    """Flag content with fewer than ``threshold`` (200) characters.

    Characters are counted after stripping leading/trailing whitespace and
    then deleting every remaining space, newline and tab.
    """
    threshold = 200

    @classmethod
    def eval(cls, input_data: List[str]) -> ResModel:
        assert len(input_data) == 1
        result = ResModel()
        # The third maketrans argument lists characters to delete.
        condensed = input_data[0].strip().translate(str.maketrans('', '', ' \n\t'))
        char_total = len(condensed)
        if char_total < cls.threshold:
            result.error_status = True
            result.error_reason = "The number of char is: " + str(char_total)
        return result
436
+
437
+
438
@Model.rule_register("QUALITY_SIGNAL_UNDERSTANDABILITY", ['sft','pretrain','benchmark'])
class CommonLineStartWithBulletpoint(BaseRule):
    """Flag content where more than 90% of lines begin with a bullet character.

    Lines come from ``split_paragraphs(..., remove_empty=True)`` with an
    identity normalizer. Leading whitespace is ignored before the prefix test
    against ``key_list``. Content with no lines is not flagged.
    """
    key_list = [
        "\u2022",  # bullet point
        "\u2023",  # triangular bullet point
        "\u25B6",  # black right pointing triangle
        "\u25C0",  # black left pointing triangle
        "\u25E6",  # white bullet point
        "\u25A0",  # black square
        "\u25A1",  # white square
        "\u25AA",  # black small square
        "\u25AB",  # white small square
        "\u2013",  # en dash
    ]

    @classmethod
    def eval(cls, input_data: List[str]) -> ResModel:
        assert len(input_data) == 1
        result = ResModel()
        lines: Tuple[TextSlice, ...] = split_paragraphs(
            text=input_data[0], normalizer=lambda x: x, remove_empty=True
        )
        line_total = len(lines)
        if line_total == 0:
            return result
        bullets = tuple(cls.key_list)
        bulleted = sum(1 for line in lines if line.text.lstrip().startswith(bullets))
        ratio = bulleted / line_total
        if ratio > 0.9:
            result.error_status = True
            result.error_reason = "The ratio of lines start with bulletpoint is: " + str(ratio)
        return result
472
+
473
+
474
@Model.rule_register("QUALITY_SIGNAL_COMPLETENESS", ['sft','pretrain','benchmark'])
class CommonLineEndWithEllipsis(BaseRule):
    """Flag content where more than 30% of lines end with an ellipsis.

    Lines come from ``split_paragraphs(..., remove_empty=True)`` with an
    identity normalizer. Trailing whitespace is ignored before the suffix
    test against ``key_list`` ("..." or "…"). Content with no lines is not
    flagged.
    """
    key_list = ["...", "…"]

    @classmethod
    def eval(cls, input_data: List[str]) -> ResModel:
        assert len(input_data) == 1
        result = ResModel()
        lines: Tuple[TextSlice, ...] = split_paragraphs(
            text=input_data[0], normalizer=lambda x: x, remove_empty=True
        )
        line_total = len(lines)
        if line_total == 0:
            return result
        endings = tuple(cls.key_list)
        trailing = sum(1 for line in lines if line.text.rstrip().endswith(endings))
        ratio = trailing / line_total
        if ratio > 0.3:
            result.error_status = True
            result.error_reason = "The ratio of lines end with ellipsis is: " + str(ratio)
        return result
497
+
498
+
499
@Model.rule_register("QUALITY_SIGNAL_COMPLETENESS", ['pretrain'])
class CommonLineEndWithTerminal(BaseRule):
    """Flag content where fewer than 60% of lines end with terminal punctuation.

    Lines come from ``split_paragraphs(..., remove_empty=True)`` with an
    identity normalizer. Trailing whitespace is ignored before the suffix
    test against ``key_list``. Content with no lines is not flagged.
    """
    key_list = [".", "!", "?", "”", "\""]

    @classmethod
    def eval(cls, input_data: List[str]) -> ResModel:
        assert len(input_data) == 1
        result = ResModel()
        lines: Tuple[TextSlice, ...] = split_paragraphs(
            text=input_data[0], normalizer=lambda x: x, remove_empty=True
        )
        line_total = len(lines)
        if line_total == 0:
            return result
        terminals = tuple(cls.key_list)
        terminated = sum(1 for line in lines if line.text.rstrip().endswith(terminals))
        ratio = terminated / line_total
        if ratio < 0.6:
            result.error_status = True
            result.error_reason = "The ratio of lines end with terminal punctuation mark is: " + str(ratio)
        return result
522
+
523
+
524
@Model.rule_register("QUALITY_SIGNAL_EFFECTIVENESS", ['sft','pretrain','benchmark'])
class CommonLineWithJavascript(BaseRule):
    """Flag content dominated by lines that mention 'javascript'.

    Lines come from ``split_paragraphs(..., normalizer=normalize,
    remove_empty=True)`` and are tested for the lowercase substring
    'javascript'. Content is flagged when it has more than 3 lines and fewer
    than 3 of them lack the word. Content with no lines is not flagged.
    """

    @classmethod
    def eval(cls, input_data: List[str]) -> ResModel:
        assert len(input_data) == 1
        result = ResModel()
        lines: Tuple[TextSlice, ...] = split_paragraphs(
            text=input_data[0], normalizer=normalize, remove_empty=True
        )
        line_total = len(lines)
        if line_total == 0:
            return result
        with_js = sum(1 for line in lines if 'javascript' in line.text)
        without_js = line_total - with_js
        if line_total > 3 and without_js < 3:
            result.error_status = True
            result.error_reason = "The lines with the word Javascript is: " + str(with_js)
        return result
546
+
547
+
548
if __name__ == '__main__':
    # Ad-hoc smoke check: run one rule on a short sample and print the result.
    sample = "DNA stands for deoxyribonucleic acid."
    rule = CommonCapitalWords()
    print(rule.eval([sample]))
@@ -0,0 +1,81 @@
1
+ import numpy as np
2
+ from PIL import Image
3
+ from typing import List
4
+
5
+ from dingo.model.model import Model
6
+ from dingo.model.rule.base import ResModel, BaseRule
7
+ from dingo.model.rule.util import *
8
+
9
# torch and pyiqa are only needed by the image-quality rules. Fail fast with an
# install hint, and chain the original error so its details stay in the traceback.
try:
    import torch
except ModuleNotFoundError as e:
    raise ModuleNotFoundError("You need to install `torch`, try `pip install torch`") from e
try:
    import pyiqa
except ModuleNotFoundError as e:
    raise ModuleNotFoundError("You need to install `pyiqa`, try `pip install pyiqa`") from e
17
+
18
@Model.rule_register('QUALITY_SIGNAL_EFFECTIVENESS', [])
class ImageValid(BaseRule):
    """Flag an image that is entirely white or entirely black.

    ``input_data[0]`` is passed to ``PIL.Image.open`` and the image is
    converted to RGB. It is flagged when every pixel equals (255, 255, 255)
    or every pixel equals (0, 0, 0).
    """
    @classmethod
    def eval(cls, input_data: List[str]) -> ResModel:
        res = ResModel()
        # Context managers release both images even if conversion or the
        # pixel comparison raises.
        with Image.open(input_data[0]) as img, img.convert("RGB") as img_new:
            img_np = np.asarray(img_new)
            if np.all(img_np == (255, 255, 255)) or np.all(img_np == (0, 0, 0)):
                res.error_status = True
                res.error_reason = 'Image is not valid: all white or black'
        return res
33
+
34
@Model.rule_register('QUALITY_SIGNAL_EFFECTIVENESS', [])
class ImageSizeValid(BaseRule):
    """Flag an image whose width/height aspect ratio is outside [0.25, 4].

    ``input_data[0]`` is passed to ``PIL.Image.open``.
    """
    @classmethod
    def eval(cls, input_data: List[str]) -> ResModel:
        res = ResModel()
        # Only the size is needed, so close the image right after reading it;
        # the context manager also closes it if anything raises.
        with Image.open(input_data[0]) as img:
            width, height = img.size
        aspect_ratio = width / height
        if aspect_ratio > 4 or aspect_ratio < 0.25:
            res.error_status = True
            res.error_reason = 'Image size is not valid, the ratio of width to height: ' + str(aspect_ratio)
        return res
48
+
49
@Model.rule_register('QUALITY_SIGNAL_EFFECTIVENESS', [])
class ImageQuality(BaseRule):
    """Flag an image whose pyiqa 'nima' score is below ``threshold`` (5.5).

    ``input_data[0]`` is passed directly to the pyiqa metric. The metric runs
    on CUDA when available, otherwise on CPU.
    """
    threshold = 5.5
    # Lazily created metric, shared across calls. Building it presumably
    # constructs the underlying network, so it should not be redone per image.
    _iqa_metric = None

    @classmethod
    def _get_metric(cls):
        """Return the shared 'nima' metric, creating it on first use."""
        if cls._iqa_metric is None:
            device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
            cls._iqa_metric = pyiqa.create_metric('nima', device=device)
        return cls._iqa_metric

    @classmethod
    def eval(cls, input_data: List[str]) -> ResModel:
        res = ResModel()
        score = cls._get_metric()(input_data[0]).item()
        if score < cls.threshold:
            res.error_status = True
            res.error_reason = 'Image quality is not satisfied, ratio: ' + str(score)
        return res
66
+
67
+ # @Model.rule_register('QUALITY_SIGNAL_SECURITY', [])
68
+ # class ImageQRCode(BaseRule):
69
+ # """check whether image contains QR code."""
70
+ # @classmethod
71
+ # def eval(cls, input_data: List[str]) -> ResModel:
72
+ # res = ResModel()
73
+ # img = cv2.imread(input_data[0])
74
+ # gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
75
+ # scanner = zbar.Scanner()
76
+ # tmp = scanner.scan(gray)
77
+ # if len(tmp) != 0:
78
+ # if tmp[0].type == 'QR-Code':
79
+ # res.error_status = True
80
+ # res.error_reason = tmp[0].data
81
+ # return res
@@ -0,0 +1,39 @@
1
+ from typing import List
2
+ import langid
3
+
4
+ from dingo.model import Model
5
+ from dingo.model.rule.base import ResModel, BaseRule
6
+
7
+
8
@Model.rule_register("QUALITY_SIGNAL_EFFECTIVENESS", [])
class PromptChineseProduceEnglish(BaseRule):
    """Flag a (prompt, prediction) pair where a Chinese prompt got an English answer.

    ``input_data`` must hold exactly two strings: the prompt, then the
    prediction. Languages are the top label from ``langid.classify``.
    """
    rule_type = 'prompt'

    @classmethod
    def eval(cls, input_data: List[str]) -> ResModel:
        result = ResModel()
        assert len(input_data) == 2
        prompt_lang, _ = langid.classify(input_data[0])
        prediction_lang, _ = langid.classify(input_data[1])
        if (prompt_lang, prediction_lang) == ('zh', 'en'):
            result.error_status = True
            result.error_reason = 'Chinese prompt, generate English content.'
        return result
23
+
24
+
25
@Model.rule_register("QUALITY_SIGNAL_EFFECTIVENESS", [])
class PromptEnglishProduceChinese(BaseRule):
    """Flag a (prompt, prediction) pair where an English prompt got a Chinese answer.

    ``input_data`` must hold exactly two strings: the prompt, then the
    prediction. Languages are the top label from ``langid.classify``.
    """
    rule_type = 'prompt'

    @classmethod
    def eval(cls, input_data: List[str]) -> ResModel:
        result = ResModel()
        assert len(input_data) == 2
        prompt_lang, _ = langid.classify(input_data[0])
        prediction_lang, _ = langid.classify(input_data[1])
        if (prompt_lang, prediction_lang) == ('en', 'zh'):
            result.error_status = True
            result.error_reason = 'English prompt, generate Chinese content'
        return result