cnhkmcp-2.3.6-py3-none-any.whl → cnhkmcp-2.3.8-py3-none-any.whl

This diff shows the differences between two publicly released versions of the package, as published in one of the supported registries. It is provided for informational purposes only and reflects the package contents exactly as they appear in that public registry.
cnhkmcp/__init__.py CHANGED
@@ -50,7 +50,7 @@ from .untracked.forum_functions import (
50
50
  read_full_forum_post
51
51
  )
52
52
 
53
- __version__ = "2.3.6"
53
+ __version__ = "2.3.8"
54
54
  __author__ = "CNHK"
55
55
  __email__ = "cnhk@example.com"
56
56
 
@@ -25,7 +25,8 @@ except ImportError:
25
25
  supported_functions = {
26
26
  # Group 类别函数
27
27
  'group_min': {'min_args': 2, 'max_args': 2, 'arg_types': ['expression', 'category']},
28
- 'group_mean': {'min_args': 3, 'max_args': 3, 'arg_types': ['expression', 'expression', 'expression']},
28
+ # group_mean(x, w, group)
29
+ 'group_mean': {'min_args': 3, 'max_args': 3, 'arg_types': ['expression', 'expression', 'category']},
29
30
  'group_median': {'min_args': 2, 'max_args': 2, 'arg_types': ['expression', 'category']},
30
31
  'group_max': {'min_args': 2, 'max_args': 2, 'arg_types': ['expression', 'category']},
31
32
  'group_rank': {'min_args': 2, 'max_args': 2, 'arg_types': ['expression', 'category']},
@@ -612,6 +613,16 @@ class ExpressionValidator:
612
613
  """验证参数类型是否符合预期"""
613
614
  errors = []
614
615
 
616
+ def _is_number_like(node: ASTNode) -> bool:
617
+ if node is None:
618
+ return False
619
+ if node.node_type == 'number':
620
+ return True
621
+ if node.node_type == 'unop' and isinstance(node.value, dict) and node.value.get('op') in {'-', '+'}:
622
+ if node.children and hasattr(node.children[0], 'node_type'):
623
+ return _is_number_like(node.children[0])
624
+ return False
625
+
615
626
  # Unit compatibility check
616
627
  # bucket()/group_cartesian_product() output a derived category (grouping key).
617
628
  # It can only be consumed where a category/grouping key is expected.
@@ -632,7 +643,8 @@ class ExpressionValidator:
632
643
  # 表达式可以是任何有效的AST节点
633
644
  pass
634
645
  elif expected_type == 'number':
635
- if arg.node_type != 'number':
646
+ # 允许 -1 这类一元负号数字常量(解析为 unop(number))
647
+ if not _is_number_like(arg):
636
648
  errors.append(f"参数 {arg_index+1} 应该是一个数字,但得到 {arg.node_type}")
637
649
  elif expected_type == 'boolean':
638
650
  # 布尔值可以是 true/false 或数字(0/1)
@@ -715,19 +727,53 @@ class ExpressionValidator:
715
727
  return cached
716
728
 
717
729
  derived = False
718
- if node.node_type == 'function' and node.value in {'bucket', 'group_cartesian_product'}:
719
- derived = True
730
+ if node.node_type == 'function':
731
+ if node.value in {'bucket', 'group_cartesian_product'}:
732
+ derived = True
733
+ else:
734
+ function_info = supported_functions.get(node.value, {})
735
+ arg_types = function_info.get('arg_types', [])
736
+ param_names = function_info.get('param_names', [])
737
+
738
+ positional_index = 0
739
+ for child in node.children:
740
+ if isinstance(child, dict):
741
+ if child.get('type') == 'named':
742
+ name = child.get('name')
743
+ value = child.get('value')
744
+
745
+ expected_type = None
746
+ if name in param_names:
747
+ param_index = param_names.index(name)
748
+ if param_index < len(arg_types):
749
+ expected_type = arg_types[param_index]
750
+
751
+ if expected_type == 'category':
752
+ continue
753
+
754
+ if self._is_derived_category(value):
755
+ derived = True
756
+ break
757
+ elif child.get('type') == 'positional':
758
+ value = child.get('value')
759
+ expected_type = arg_types[positional_index] if positional_index < len(arg_types) else None
760
+
761
+ if expected_type != 'category' and self._is_derived_category(value):
762
+ derived = True
763
+ break
764
+ positional_index += 1
765
+ else:
766
+ expected_type = arg_types[positional_index] if positional_index < len(arg_types) else None
767
+ if expected_type != 'category' and self._is_derived_category(child):
768
+ derived = True
769
+ break
770
+ positional_index += 1
720
771
  elif node.node_type in {'unop', 'binop'}:
721
772
  derived = any(
722
773
  self._is_derived_category(child)
723
774
  for child in node.children
724
775
  if hasattr(child, 'node_type')
725
776
  )
726
- elif node.node_type == 'function':
727
- derived = any(
728
- self._is_derived_category(child.get('value')) if isinstance(child, dict) else self._is_derived_category(child)
729
- for child in node.children
730
- )
731
777
 
732
778
  self._derived_category_cache[cache_key] = derived
733
779
  return derived
@@ -859,6 +905,172 @@ class ExpressionValidator:
859
905
  Returns:
860
906
  Tuple[bool, str]: (是否成功, 转换后的表达式或错误信息)
861
907
  """
908
+ def _top_level_equals_positions(stmt: str) -> List[int]:
909
+ """返回所有“顶层赋值”等号位置。
910
+
911
+ 仅统计括号外(()[]{})、引号外、且不属于比较操作符(==,!=,<=,>=)的 '='。
912
+ 这样可以避免把关键字参数(如 rettype=0)误判为赋值语句。
913
+ """
914
+ positions: List[int] = []
915
+ paren_depth = 0
916
+ bracket_depth = 0
917
+ brace_depth = 0
918
+ in_single_quote = False
919
+ in_double_quote = False
920
+ escape = False
921
+
922
+ for i, ch in enumerate(stmt):
923
+ if escape:
924
+ escape = False
925
+ continue
926
+ if ch == '\\':
927
+ escape = True
928
+ continue
929
+
930
+ if in_single_quote:
931
+ if ch == "'":
932
+ in_single_quote = False
933
+ continue
934
+ if in_double_quote:
935
+ if ch == '"':
936
+ in_double_quote = False
937
+ continue
938
+
939
+ if ch == "'":
940
+ in_single_quote = True
941
+ continue
942
+ if ch == '"':
943
+ in_double_quote = True
944
+ continue
945
+
946
+ if ch == '(':
947
+ paren_depth += 1
948
+ continue
949
+ if ch == ')':
950
+ paren_depth = max(0, paren_depth - 1)
951
+ continue
952
+ if ch == '[':
953
+ bracket_depth += 1
954
+ continue
955
+ if ch == ']':
956
+ bracket_depth = max(0, bracket_depth - 1)
957
+ continue
958
+ if ch == '{':
959
+ brace_depth += 1
960
+ continue
961
+ if ch == '}':
962
+ brace_depth = max(0, brace_depth - 1)
963
+ continue
964
+
965
+ if paren_depth or bracket_depth or brace_depth:
966
+ continue
967
+
968
+ if ch != '=':
969
+ continue
970
+
971
+ prev_ch = stmt[i - 1] if i > 0 else ''
972
+ next_ch = stmt[i + 1] if i + 1 < len(stmt) else ''
973
+ if prev_ch in ['=', '!', '<', '>'] or next_ch == '=':
974
+ continue
975
+
976
+ positions.append(i)
977
+
978
+ return positions
979
+
980
+ def _keyword_arg_names(stmt: str):
981
+ """提取函数调用中的命名参数名(如 rettype=0 中的 rettype)。
982
+
983
+ 只收集括号/中括号/大括号内部出现的 name= 形式,避免把脚本级赋值误当作命名参数。
984
+ """
985
+ names = set()
986
+ paren_depth = 0
987
+ bracket_depth = 0
988
+ brace_depth = 0
989
+ in_single_quote = False
990
+ in_double_quote = False
991
+ escape = False
992
+
993
+ i = 0
994
+ while i < len(stmt):
995
+ ch = stmt[i]
996
+
997
+ if escape:
998
+ escape = False
999
+ i += 1
1000
+ continue
1001
+ if ch == '\\':
1002
+ escape = True
1003
+ i += 1
1004
+ continue
1005
+
1006
+ if in_single_quote:
1007
+ if ch == "'":
1008
+ in_single_quote = False
1009
+ i += 1
1010
+ continue
1011
+ if in_double_quote:
1012
+ if ch == '"':
1013
+ in_double_quote = False
1014
+ i += 1
1015
+ continue
1016
+
1017
+ if ch == "'":
1018
+ in_single_quote = True
1019
+ i += 1
1020
+ continue
1021
+ if ch == '"':
1022
+ in_double_quote = True
1023
+ i += 1
1024
+ continue
1025
+
1026
+ if ch == '(':
1027
+ paren_depth += 1
1028
+ i += 1
1029
+ continue
1030
+ if ch == ')':
1031
+ paren_depth = max(0, paren_depth - 1)
1032
+ i += 1
1033
+ continue
1034
+ if ch == '[':
1035
+ bracket_depth += 1
1036
+ i += 1
1037
+ continue
1038
+ if ch == ']':
1039
+ bracket_depth = max(0, bracket_depth - 1)
1040
+ i += 1
1041
+ continue
1042
+ if ch == '{':
1043
+ brace_depth += 1
1044
+ i += 1
1045
+ continue
1046
+ if ch == '}':
1047
+ brace_depth = max(0, brace_depth - 1)
1048
+ i += 1
1049
+ continue
1050
+
1051
+ inside_container = bool(paren_depth or bracket_depth or brace_depth)
1052
+
1053
+ if inside_container and (ch.isalpha() or ch == '_'):
1054
+ start = i
1055
+ i += 1
1056
+ while i < len(stmt) and (stmt[i].isalnum() or stmt[i] == '_'):
1057
+ i += 1
1058
+ name = stmt[start:i]
1059
+
1060
+ j = i
1061
+ while j < len(stmt) and stmt[j].isspace():
1062
+ j += 1
1063
+
1064
+ if j < len(stmt) and stmt[j] == '=':
1065
+ next_ch = stmt[j + 1] if j + 1 < len(stmt) else ''
1066
+ if next_ch != '=':
1067
+ names.add(name.lower())
1068
+ continue
1069
+
1070
+ i += 1
1071
+
1072
+ return names
1073
+
862
1074
  # 检查表达式是否以分号结尾
863
1075
  if expression.strip().endswith(';'):
864
1076
  return False, "表达式不能以分号结尾"
@@ -873,51 +1085,13 @@ class ExpressionValidator:
873
1085
 
874
1086
  # 处理每个赋值语句(除了最后一个)
875
1087
  for i, stmt in enumerate(statements[:-1]):
876
- # 检查是否包含赋值符号
877
- if '=' not in stmt:
878
- return False, f"第{i+1}个语句必须是赋值语句(使用=符号)"
879
-
880
- # 检查是否是比较操作符(==, !=, <=, >=)
881
- if any(op in stmt for op in ['==', '!=', '<=', '>=']):
882
- # 如果包含比较操作符,需要确认是否有赋值符号
883
- # 使用临时替换法:将比较操作符替换为临时标记,再检查是否还有=
884
- temp_stmt = stmt
885
- for op in ['==', '!=', '<=', '>=']:
886
- temp_stmt = temp_stmt.replace(op, '---')
887
-
888
- if '=' not in temp_stmt:
889
- return False, f"第{i+1}个语句必须是赋值语句,不能只是比较表达式"
890
-
891
- # 找到第一个=符号(不是比较操作符的一部分)
892
- # 先将比较操作符替换为临时标记,再找=
893
- temp_stmt = stmt
894
- for op in ['==', '!=', '<=', '>=']:
895
- temp_stmt = temp_stmt.replace(op, '---')
896
-
897
- if '=' not in temp_stmt:
1088
+ eq_positions = _top_level_equals_positions(stmt)
1089
+ if not eq_positions:
898
1090
  return False, f"第{i+1}个语句必须是赋值语句(使用=符号)"
899
-
900
- # 找到实际的=位置
901
- equals_pos = temp_stmt.index('=')
902
-
903
- # 在原始语句中找到对应位置
904
- real_equals_pos = 0
905
- temp_count = 0
906
- for char in stmt:
907
- if temp_count == equals_pos:
908
- break
909
- if char in '!<>':
910
- # 检查是否是比较操作符的一部分
911
- if real_equals_pos + 1 < len(stmt) and stmt[real_equals_pos + 1] == '=':
912
- # 是比较操作符,跳过两个字符
913
- real_equals_pos += 2
914
- temp_count += 3 # 因为替换成了三个字符的---
915
- else:
916
- real_equals_pos += 1
917
- temp_count += 1
918
- else:
919
- real_equals_pos += 1
920
- temp_count += 1
1091
+ if len(eq_positions) > 1:
1092
+ return False, f"第{i+1}个语句只能包含一个赋值符号(=)"
1093
+
1094
+ real_equals_pos = eq_positions[0]
921
1095
 
922
1096
  # 分割变量名和值
923
1097
  var_name = stmt[:real_equals_pos].strip()
@@ -934,9 +1108,12 @@ class ExpressionValidator:
934
1108
 
935
1109
  # 检查变量值中使用的变量是否已经定义
936
1110
  # 简单检查:提取所有可能的变量名
1111
+ kw_names = _keyword_arg_names(var_value)
937
1112
  used_vars = re.findall(r'\b[a-zA-Z_][a-zA-Z0-9_]*\b', var_value)
938
1113
  for used_var in used_vars:
939
1114
  used_var_lower = used_var.lower()
1115
+ if used_var_lower in kw_names:
1116
+ continue
940
1117
  if used_var_lower not in variables:
941
1118
  # 检查是否是函数名
942
1119
  if used_var not in supported_functions:
@@ -959,19 +1136,16 @@ class ExpressionValidator:
959
1136
  final_stmt = statements[-1]
960
1137
 
961
1138
  # 检查最后一个语句是否是赋值语句
962
- if '=' in final_stmt:
963
- # 替换比较操作符为临时标记,然后检查是否还有单独的=
964
- temp_stmt = final_stmt
965
- for op in ['==', '!=', '<=', '>=']:
966
- temp_stmt = temp_stmt.replace(op, '---')
967
-
968
- if '=' in temp_stmt:
969
- return False, "最后一个语句不能是赋值语句"
1139
+ if _top_level_equals_positions(final_stmt):
1140
+ return False, "最后一个语句不能是赋值语句"
970
1141
 
971
1142
  # 检查最后一个语句中使用的变量是否已经定义
1143
+ kw_names = _keyword_arg_names(final_stmt)
972
1144
  used_vars = re.findall(r'\b[a-zA-Z_][a-zA-Z0-9_]*\b', final_stmt)
973
1145
  for used_var in used_vars:
974
1146
  used_var_lower = used_var.lower()
1147
+ if used_var_lower in kw_names:
1148
+ continue
975
1149
  if used_var_lower not in variables:
976
1150
  # 检查是否是函数名
977
1151
  if used_var not in supported_functions: