cnhkmcp 2.3.6__py3-none-any.whl → 2.3.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -25,7 +25,8 @@ except ImportError:
25
25
  supported_functions = {
26
26
  # Group 类别函数
27
27
  'group_min': {'min_args': 2, 'max_args': 2, 'arg_types': ['expression', 'category']},
28
- 'group_mean': {'min_args': 3, 'max_args': 3, 'arg_types': ['expression', 'expression', 'expression']},
28
+ # group_mean(x, w, group)
29
+ 'group_mean': {'min_args': 3, 'max_args': 3, 'arg_types': ['expression', 'expression', 'category']},
29
30
  'group_median': {'min_args': 2, 'max_args': 2, 'arg_types': ['expression', 'category']},
30
31
  'group_max': {'min_args': 2, 'max_args': 2, 'arg_types': ['expression', 'category']},
31
32
  'group_rank': {'min_args': 2, 'max_args': 2, 'arg_types': ['expression', 'category']},
@@ -612,6 +613,16 @@ class ExpressionValidator:
612
613
  """验证参数类型是否符合预期"""
613
614
  errors = []
614
615
 
616
def _is_number_like(node: "ASTNode") -> bool:
    """Return True when *node* is a numeric literal, optionally signed.

    A bare ``number`` node qualifies, and so does any chain of unary ``-`` /
    ``+`` operators (``unop`` nodes whose ``value`` is a dict with an ``op``
    key) that ultimately wraps a ``number`` node, e.g. ``-1`` or ``+-2``.

    The chain is walked iteratively so that a very long run of unary signs
    cannot exhaust the interpreter's recursion limit.

    Args:
        node: AST node to inspect; ``None`` is accepted and yields False.

    Returns:
        True if the node is number-like, False otherwise.
    """
    signs = ('-', '+')
    current = node
    while current is not None:
        kind = getattr(current, 'node_type', None)
        if kind == 'number':
            return True
        if kind != 'unop':
            # Anything that is neither a literal nor a unary operator
            # (including objects without a node_type) is not a number.
            return False

        op_info = getattr(current, 'value', None)
        if not isinstance(op_info, dict) or op_info.get('op') not in signs:
            return False

        operands = getattr(current, 'children', None)
        if not operands:
            return False
        # Descend into the single operand of the unary sign.
        current = operands[0]
    return False
625
+
615
626
  # Unit compatibility check
616
627
  # bucket()/group_cartesian_product() output a derived category (grouping key).
617
628
  # It can only be consumed where a category/grouping key is expected (typically by group_* operators).
@@ -633,7 +644,8 @@ class ExpressionValidator:
633
644
  # 表达式可以是任何有效的AST节点
634
645
  pass
635
646
  elif expected_type == 'number':
636
- if arg.node_type != 'number':
647
+ # 允许 -1 这类一元负号数字常量(解析为 unop(number))
648
+ if not _is_number_like(arg):
637
649
  errors.append(f"参数 {arg_index+1} 应该是一个数字,但得到 {arg.node_type}")
638
650
  elif expected_type == 'boolean':
639
651
  # 布尔值可以是 true/false 或数字(0/1)
@@ -719,19 +731,54 @@ class ExpressionValidator:
719
731
  return cached
720
732
 
721
733
  derived = False
722
- if node.node_type == 'function' and node.value in {'bucket', 'group_cartesian_product'}:
723
- derived = True
734
+ if node.node_type == 'function':
735
+ if node.value in {'bucket', 'group_cartesian_product'}:
736
+ derived = True
737
+ else:
738
+ function_info = supported_functions.get(node.value, {})
739
+ arg_types = function_info.get('arg_types', [])
740
+ param_names = function_info.get('param_names', [])
741
+
742
+ positional_index = 0
743
+ for child in node.children:
744
+ if isinstance(child, dict):
745
+ if child.get('type') == 'named':
746
+ name = child.get('name')
747
+ value = child.get('value')
748
+
749
+ expected_type = None
750
+ if name in param_names:
751
+ param_index = param_names.index(name)
752
+ if param_index < len(arg_types):
753
+ expected_type = arg_types[param_index]
754
+
755
+ # Do not propagate "derived" through allowed category/grouping-key inputs.
756
+ if expected_type == 'category':
757
+ continue
758
+
759
+ if self._is_derived_category(value):
760
+ derived = True
761
+ break
762
+ elif child.get('type') == 'positional':
763
+ value = child.get('value')
764
+ expected_type = arg_types[positional_index] if positional_index < len(arg_types) else None
765
+
766
+ if expected_type != 'category' and self._is_derived_category(value):
767
+ derived = True
768
+ break
769
+ positional_index += 1
770
+ else:
771
+ expected_type = arg_types[positional_index] if positional_index < len(arg_types) else None
772
+ if expected_type != 'category' and self._is_derived_category(child):
773
+ derived = True
774
+ break
775
+ positional_index += 1
724
776
  elif node.node_type in {'unop', 'binop'}:
725
777
  derived = any(
726
778
  self._is_derived_category(child)
727
779
  for child in node.children
728
780
  if hasattr(child, 'node_type')
729
781
  )
730
- elif node.node_type == 'function':
731
- derived = any(
732
- self._is_derived_category(child.get('value')) if isinstance(child, dict) else self._is_derived_category(child)
733
- for child in node.children
734
- )
735
782
 
736
783
  self._derived_category_cache[cache_key] = derived
737
784
  return derived
@@ -865,6 +912,173 @@ class ExpressionValidator:
865
912
  Returns:
866
913
  Tuple[bool, str]: (是否成功, 转换后的表达式或错误信息)
867
914
  """
915
def _top_level_equals_positions(stmt: str) -> List[int]:
    """Return the indices of every top-level assignment ``=`` in *stmt*.

    An ``=`` is counted only when it is
      * outside all ``()``, ``[]`` and ``{}`` pairs,
      * outside single- and double-quoted strings, and
      * not part of a comparison operator (``==``, ``!=``, ``<=``, ``>=``).

    This keeps keyword arguments such as ``rettype=0`` inside a call from
    being mistaken for a statement-level assignment.

    Args:
        stmt: A single statement of the expression script.

    Returns:
        Zero-based positions of the qualifying ``=`` characters, in order.
    """
    closer_to_opener = {')': '(', ']': '[', '}': '{'}
    # One counter per bracket family; each is clamped at zero so that a
    # stray closing bracket cannot drive the depth negative.
    depth = {'(': 0, '[': 0, '{': 0}
    comparison_prefixes = frozenset('=!<>')

    positions: List[int] = []
    active_quote = None  # the quote character we are currently inside, if any
    escaped = False
    last_index = len(stmt) - 1

    for idx, ch in enumerate(stmt):
        # A backslash makes the following character inert.
        if escaped:
            escaped = False
            continue
        if ch == '\\':
            escaped = True
            continue

        if active_quote is not None:
            if ch == active_quote:
                active_quote = None
            continue
        if ch == "'" or ch == '"':
            active_quote = ch
            continue

        if ch in depth:
            depth[ch] += 1
            continue
        if ch in closer_to_opener:
            opener = closer_to_opener[ch]
            depth[opener] = max(0, depth[opener] - 1)
            continue

        if ch != '=' or any(depth.values()):
            continue

        # Discard '=' characters that belong to ==, !=, <= or >=.
        prev_ch = stmt[idx - 1] if idx > 0 else ''
        next_ch = stmt[idx + 1] if idx < last_index else ''
        if prev_ch in comparison_prefixes or next_ch == '=':
            continue

        positions.append(idx)

    return positions
987
+
988
def _keyword_arg_names(stmt: str) -> set:
    """Collect the names used as keyword arguments inside *stmt*.

    For ``ts_rank(x, rettype=0)`` this yields ``{'rettype'}``. Only
    ``name=`` occurrences that sit inside ``()``, ``[]`` or ``{}`` are
    collected, so a script-level assignment such as ``a = rank(x)`` is not
    reported. ``name ==`` (a comparison) is ignored, and text inside quoted
    strings is skipped.

    Args:
        stmt: A statement, or the right-hand side of an assignment.

    Returns:
        A set of lower-cased keyword-argument names.
    """
    closer_to_opener = {')': '(', ']': '[', '}': '{'}
    # One counter per bracket family, clamped at zero on stray closers.
    depth = {'(': 0, '[': 0, '{': 0}

    names = set()
    active_quote = None  # the quote character we are currently inside, if any
    escaped = False
    length = len(stmt)
    pos = 0

    while pos < length:
        ch = stmt[pos]
        pos += 1  # single place where the cursor advances past `ch`

        # A backslash makes the following character inert.
        if escaped:
            escaped = False
            continue
        if ch == '\\':
            escaped = True
            continue

        if active_quote is not None:
            if ch == active_quote:
                active_quote = None
            continue
        if ch == "'" or ch == '"':
            active_quote = ch
            continue

        if ch in depth:
            depth[ch] += 1
            continue
        if ch in closer_to_opener:
            opener = closer_to_opener[ch]
            depth[opener] = max(0, depth[opener] - 1)
            continue

        # Only identifiers nested inside some bracket can be keyword names.
        if not any(depth.values()):
            continue
        if not (ch.isalpha() or ch == '_'):
            continue

        # Consume the rest of the identifier that starts at `ch`.
        start = pos - 1
        while pos < length and (stmt[pos].isalnum() or stmt[pos] == '_'):
            pos += 1

        # Look past optional whitespace for a lone '=' (but not '==').
        probe = pos
        while probe < length and stmt[probe].isspace():
            probe += 1
        if stmt[probe:probe + 1] == '=' and stmt[probe + 1:probe + 2] != '=':
            names.add(stmt[start:pos].lower())
        # Scanning resumes at `pos`, the first character after the identifier.

    return names
1081
+
868
1082
  # 检查表达式是否以分号结尾
869
1083
  if expression.strip().endswith(';'):
870
1084
  return False, "表达式不能以分号结尾"
@@ -879,51 +1093,13 @@ class ExpressionValidator:
879
1093
 
880
1094
  # 处理每个赋值语句(除了最后一个)
881
1095
  for i, stmt in enumerate(statements[:-1]):
882
- # 检查是否包含赋值符号
883
- if '=' not in stmt:
884
- return False, f"第{i+1}个语句必须是赋值语句(使用=符号)"
885
-
886
- # 检查是否是比较操作符(==, !=, <=, >=)
887
- if any(op in stmt for op in ['==', '!=', '<=', '>=']):
888
- # 如果包含比较操作符,需要确认是否有赋值符号
889
- # 使用临时替换法:将比较操作符替换为临时标记,再检查是否还有=
890
- temp_stmt = stmt
891
- for op in ['==', '!=', '<=', '>=']:
892
- temp_stmt = temp_stmt.replace(op, '---')
893
-
894
- if '=' not in temp_stmt:
895
- return False, f"第{i+1}个语句必须是赋值语句,不能只是比较表达式"
896
-
897
- # 找到第一个=符号(不是比较操作符的一部分)
898
- # 先将比较操作符替换为临时标记,再找=
899
- temp_stmt = stmt
900
- for op in ['==', '!=', '<=', '>=']:
901
- temp_stmt = temp_stmt.replace(op, '---')
902
-
903
- if '=' not in temp_stmt:
1096
+ eq_positions = _top_level_equals_positions(stmt)
1097
+ if not eq_positions:
904
1098
  return False, f"第{i+1}个语句必须是赋值语句(使用=符号)"
905
-
906
- # 找到实际的=位置
907
- equals_pos = temp_stmt.index('=')
908
-
909
- # 在原始语句中找到对应位置
910
- real_equals_pos = 0
911
- temp_count = 0
912
- for char in stmt:
913
- if temp_count == equals_pos:
914
- break
915
- if char in '!<>':
916
- # 检查是否是比较操作符的一部分
917
- if real_equals_pos + 1 < len(stmt) and stmt[real_equals_pos + 1] == '=':
918
- # 是比较操作符,跳过两个字符
919
- real_equals_pos += 2
920
- temp_count += 3 # 因为替换成了三个字符的---
921
- else:
922
- real_equals_pos += 1
923
- temp_count += 1
924
- else:
925
- real_equals_pos += 1
926
- temp_count += 1
1099
+ if len(eq_positions) > 1:
1100
+ return False, f"第{i+1}个语句只能包含一个赋值符号(=)"
1101
+
1102
+ real_equals_pos = eq_positions[0]
927
1103
 
928
1104
  # 分割变量名和值
929
1105
  var_name = stmt[:real_equals_pos].strip()
@@ -940,9 +1116,12 @@ class ExpressionValidator:
940
1116
 
941
1117
  # 检查变量值中使用的变量是否已经定义
942
1118
  # 简单检查:提取所有可能的变量名
1119
+ kw_names = _keyword_arg_names(var_value)
943
1120
  used_vars = re.findall(r'\b[a-zA-Z_][a-zA-Z0-9_]*\b', var_value)
944
1121
  for used_var in used_vars:
945
1122
  used_var_lower = used_var.lower()
1123
+ if used_var_lower in kw_names:
1124
+ continue
946
1125
  if used_var_lower not in variables:
947
1126
  # 检查是否是函数名
948
1127
  if used_var not in supported_functions:
@@ -965,19 +1144,16 @@ class ExpressionValidator:
965
1144
  final_stmt = statements[-1]
966
1145
 
967
1146
  # 检查最后一个语句是否是赋值语句
968
- if '=' in final_stmt:
969
- # 替换比较操作符为临时标记,然后检查是否还有单独的=
970
- temp_stmt = final_stmt
971
- for op in ['==', '!=', '<=', '>=']:
972
- temp_stmt = temp_stmt.replace(op, '---')
973
-
974
- if '=' in temp_stmt:
975
- return False, "最后一个语句不能是赋值语句"
1147
+ if _top_level_equals_positions(final_stmt):
1148
+ return False, "最后一个语句不能是赋值语句"
976
1149
 
977
1150
  # 检查最后一个语句中使用的变量是否已经定义
1151
+ kw_names = _keyword_arg_names(final_stmt)
978
1152
  used_vars = re.findall(r'\b[a-zA-Z_][a-zA-Z0-9_]*\b', final_stmt)
979
1153
  for used_var in used_vars:
980
1154
  used_var_lower = used_var.lower()
1155
+ if used_var_lower in kw_names:
1156
+ continue
981
1157
  if used_var_lower not in variables:
982
1158
  # 检查是否是函数名
983
1159
  if used_var not in supported_functions: