From f1ee2cb2900d27fb376831e11c8d33ccd97fb604 Mon Sep 17 00:00:00 2001 From: Alexander Medvednikov Date: Wed, 25 Mar 2026 16:42:26 +0300 Subject: [PATCH] checker: fix type-casting issue when providing generic function as argument (fixes #21132) --- vlib/v/checker/checker.v | 244 ++++++++++++------ ...generic_function_argument_inference_test.v | 26 ++ 2 files changed, 185 insertions(+), 85 deletions(-) create mode 100644 vlib/v/tests/generics/generic_function_argument_inference_test.v diff --git a/vlib/v/checker/checker.v b/vlib/v/checker/checker.v index ac0cedc95..9f31b7056 100644 --- a/vlib/v/checker/checker.v +++ b/vlib/v/checker/checker.v @@ -5349,129 +5349,207 @@ fn (mut c Checker) at_expr(mut node ast.AtExpr) ast.Type { return ast.string_type } -fn (mut c Checker) infer_fn_value_generic_type(pattern ast.Type, actual ast.Type, generic_names []string, mut inferred map[string]ast.Type) bool { - pattern_typ := c.unwrap_generic(pattern) - actual_typ := c.unwrap_generic(actual) - pattern_root := pattern_typ.clear_option_and_result().set_nr_muls(0) - pattern_root_sym := c.table.sym(pattern_root) - if pattern_root.has_flag(.generic) && pattern_root_sym.name in generic_names { - if pattern_typ.has_flag(.option) && !actual_typ.has_flag(.option) { - return false - } - if pattern_typ.has_flag(.result) && !actual_typ.has_flag(.result) { - return false - } - if actual_typ.nr_muls() < pattern_typ.nr_muls() { +fn (mut c Checker) same_inferred_fn_value_type(left ast.Type, right ast.Type) bool { + if left == right { + return true + } + if left.nr_muls() != right.nr_muls() { + return false + } + if left.has_flag(.option) != right.has_flag(.option) { + return false + } + if left.has_flag(.result) != right.has_flag(.result) { + return false + } + if left.has_flag(.variadic) != right.has_flag(.variadic) { + return false + } + if left.share() != right.share() { + return false + } + left_base := c.table.unaliased_type(left.clear_flags(.option, .result, .variadic, + .shared_f, .atomic_f).clear_ref()) + right_base := c.table.unaliased_type(right.clear_flags(.option, .result, .variadic, + .shared_f, .atomic_f).clear_ref()) + return left_base == right_base +} + +fn (mut c Checker) bind_inferred_fn_value_type(mut inferred map[string]ast.Type, generic_name string, concrete_type ast.Type) bool { + if generic_name in inferred { + return c.same_inferred_fn_value_type(inferred[generic_name], concrete_type) + } + inferred[generic_name] = concrete_type + return true +} + +fn (mut c Checker) infer_fn_value_concrete_type(mut inferred map[string]ast.Type, generic_names []string, generic_type ast.Type, concrete_type ast.Type) bool { + if generic_type.has_flag(.option) != concrete_type.has_flag(.option) { + return false + } + if generic_type.has_flag(.result) != concrete_type.has_flag(.result) { + return false + } + if generic_type.has_flag(.variadic) != concrete_type.has_flag(.variadic) { + return false + } + if generic_type.share() != concrete_type.share() { + return false + } + mut template_type := generic_type.clear_flags(.option, .result, .variadic, .shared_f, + .atomic_f) + mut actual_type := concrete_type.clear_flags(.option, .result, .variadic, .shared_f, + .atomic_f) + template_sym := c.table.sym(template_type) + if template_sym.name in generic_names { + if template_type.nr_muls() > actual_type.nr_muls() { return false } - mut inferred_typ := actual_typ - if pattern_typ.has_flag(.option) { - inferred_typ = inferred_typ.clear_flag(.option) - } - if pattern_typ.has_flag(.result) { - inferred_typ = inferred_typ.clear_flag(.result) - } - inferred_typ = inferred_typ.set_nr_muls(actual_typ.nr_muls() - pattern_typ.nr_muls()) - if pattern_root_sym.name in inferred { - return c.table.unaliased_type(inferred[pattern_root_sym.name]) == c.table.unaliased_type(inferred_typ) + inferred_type := if template_type.nr_muls() > 0 { + actual_type.set_nr_muls(actual_type.nr_muls() - template_type.nr_muls()) + } else { + actual_type } - inferred[pattern_root_sym.name] = inferred_typ - return true + return c.bind_inferred_fn_value_type(mut inferred, template_sym.name, inferred_type) } - pattern_sym := c.table.final_sym(pattern_typ) - actual_sym := c.table.final_sym(actual_typ) - match pattern_sym.info { + if template_type.nr_muls() != actual_type.nr_muls() { + return false + } + template_type = template_type.clear_ref() + actual_type = actual_type.clear_ref() + template_final_sym := c.table.final_sym(template_type) + actual_final_sym := c.table.final_sym(actual_type) + match template_final_sym.info { ast.Array { - if actual_sym.info is ast.Array { - return c.infer_fn_value_generic_type(pattern_sym.info.elem_type, actual_sym.info.elem_type, - generic_names, mut inferred) + if actual_final_sym.info !is ast.Array { + return false } + template_array := template_final_sym.info as ast.Array + actual_array := actual_final_sym.info as ast.Array + if template_array.nr_dims != actual_array.nr_dims { + return false + } + return c.infer_fn_value_concrete_type(mut inferred, generic_names, template_array.elem_type, + actual_array.elem_type) } ast.ArrayFixed { - if actual_sym.info is ast.ArrayFixed { - return c.infer_fn_value_generic_type(pattern_sym.info.elem_type, actual_sym.info.elem_type, - generic_names, mut inferred) + if actual_final_sym.info !is ast.ArrayFixed { + return false + } + template_array_fixed := template_final_sym.info as ast.ArrayFixed + actual_array_fixed := actual_final_sym.info as ast.ArrayFixed + if template_array_fixed.size != actual_array_fixed.size { + return false } + return c.infer_fn_value_concrete_type(mut inferred, generic_names, template_array_fixed.elem_type, + actual_array_fixed.elem_type) + } + ast.Chan { + if actual_final_sym.info !is ast.Chan { + return false + } + template_chan := template_final_sym.info as ast.Chan + actual_chan := actual_final_sym.info as ast.Chan + return c.infer_fn_value_concrete_type(mut inferred, generic_names, template_chan.elem_type, + actual_chan.elem_type) } ast.Map { - if actual_sym.info is ast.Map { - return - c.infer_fn_value_generic_type(pattern_sym.info.key_type, actual_sym.info.key_type, generic_names, mut inferred) - && c.infer_fn_value_generic_type(pattern_sym.info.value_type, actual_sym.info.value_type, generic_names, mut inferred) + if actual_final_sym.info !is ast.Map { + return false + } + template_map := template_final_sym.info as ast.Map + actual_map := actual_final_sym.info as ast.Map + return + c.infer_fn_value_concrete_type(mut inferred, generic_names, template_map.key_type, actual_map.key_type) + && c.infer_fn_value_concrete_type(mut inferred, generic_names, template_map.value_type, actual_map.value_type) + } + ast.Thread { + if actual_final_sym.info !is ast.Thread { + return false } + template_thread := template_final_sym.info as ast.Thread + actual_thread := actual_final_sym.info as ast.Thread + return c.infer_fn_value_concrete_type(mut inferred, generic_names, template_thread.return_type, + actual_thread.return_type) } ast.FnType { - if actual_sym.info is ast.FnType { - pattern_fn := pattern_sym.info.func - actual_fn := actual_sym.info.func - if pattern_fn.params.len != actual_fn.params.len { + if actual_final_sym.info !is ast.FnType { + return false + } + template_fn := (template_final_sym.info as ast.FnType).func + actual_fn := (actual_final_sym.info as ast.FnType).func + if template_fn.params.len != actual_fn.params.len + || template_fn.is_variadic != actual_fn.is_variadic { + return false + } + for i, template_param in template_fn.params { + actual_param := actual_fn.params[i] + if template_param.is_mut != actual_param.is_mut { return false } - if !c.infer_fn_value_generic_type(pattern_fn.return_type, actual_fn.return_type, - generic_names, mut inferred) { + if !c.infer_fn_value_concrete_type(mut inferred, generic_names, template_param.typ, + actual_param.typ) { return false } - for i, pattern_param in pattern_fn.params { - actual_param := actual_fn.params[i] - if pattern_param.is_mut != actual_param.is_mut { - return false - } - if !c.infer_fn_value_generic_type(pattern_param.typ, actual_param.typ, - generic_names, mut inferred) { - return false - } - } } + return c.infer_fn_value_concrete_type(mut inferred, generic_names, template_fn.return_type, + actual_fn.return_type) + } + else { + return c.table.unaliased_type(template_type) == c.table.unaliased_type(actual_type) } - else {} } - return true } -fn (mut c Checker) infer_fn_value_concrete_types(func &ast.Fn, expected_type ast.Type) []ast.Type { - if func.generic_names.len == 0 || expected_type in [0, ast.void_type] { - return []ast.Type{} +fn (mut c Checker) infer_fn_value_concrete_types(func &ast.Fn, expected_type ast.Type) ?[]ast.Type { + if expected_type in [0, ast.void_type] { + return none } - expected_typ := c.unwrap_generic(expected_type) - expected_sym := c.table.final_sym(expected_typ) - if expected_sym.info !is ast.FnType { - return []ast.Type{} + expected_sym := c.table.final_sym(expected_type) + if expected_sym.kind != .function || expected_sym.info !is ast.FnType { + return none } expected_fn := (expected_sym.info as ast.FnType).func - if func.params.len != expected_fn.params.len { - return []ast.Type{} + if func.params.len != expected_fn.params.len || func.is_variadic != expected_fn.is_variadic { + return none } mut inferred := map[string]ast.Type{} - if !c.infer_fn_value_generic_type(func.return_type, expected_fn.return_type, func.generic_names, mut - inferred) { - return []ast.Type{} - } for i, param in func.params { expected_param := expected_fn.params[i] if param.is_mut != expected_param.is_mut { - return []ast.Type{} + return none } - if !c.infer_fn_value_generic_type(param.typ, expected_param.typ, func.generic_names, mut - inferred) { - return []ast.Type{} + if !c.infer_fn_value_concrete_type(mut inferred, func.generic_names, param.typ, + expected_param.typ) { + return none } } - if inferred.len != func.generic_names.len { - return []ast.Type{} + if !c.infer_fn_value_concrete_type(mut inferred, func.generic_names, func.return_type, + expected_fn.return_type) { + return none } mut concrete_types := []ast.Type{cap: func.generic_names.len} for generic_name in func.generic_names { if generic_name !in inferred { - return []ast.Type{} + return none } concrete_types << inferred[generic_name] } return concrete_types } +fn (mut c Checker) infer_ident_fn_value_concrete_types(func &ast.Fn, mut node ast.Ident) { + if func.generic_names.len == 0 || node.concrete_types.len > 0 { + return + } + if concrete_types := c.infer_fn_value_concrete_types(func, c.expected_type) { + node.concrete_types = concrete_types + } +} + fn (mut c Checker) resolve_var_fn(func &ast.Fn, mut node ast.Ident, name string) ast.Type { if func.generic_names.len > 0 && node.concrete_types.len == 0 { - node.concrete_types = c.infer_fn_value_concrete_types(func, c.expected_type) + c.infer_ident_fn_value_concrete_types(func, mut node) } mut fn_type := c.table.find_or_register_fn_type(func, false, true) if fn_type < 0 { @@ -5585,20 +5663,18 @@ fn (mut c Checker) ident(mut node ast.Ident) ast.Type { } return typ } else if node.kind == .function { - info := node.info as ast.IdentFn if func := c.table.find_fn(node.name) { + c.infer_ident_fn_value_concrete_types(func, mut node) if func.generic_names.len > 0 { - if node.concrete_types.len == 0 { - node.concrete_types = c.infer_fn_value_concrete_types(func, c.expected_type) - } if node.concrete_types.len == 0 { c.error('`${node.name}` is a generic fn, you should pass its concrete types, e.g. ${node.name}[int]', node.pos) } return c.resolve_var_fn(func, mut node, node.name) } + return c.resolve_var_fn(func, mut node, node.name) } - return info.typ + return (node.info as ast.IdentFn).typ } else if node.kind == .unresolved { // first use if node.tok_kind == .assign && node.is_mut { @@ -5818,10 +5894,8 @@ fn (mut c Checker) ident(mut node ast.Ident) ast.Type { } // Non-anon-function object (not a call), e.g. `onclick(my_click)` if func := c.table.find_fn(name) { + c.infer_ident_fn_value_concrete_types(func, mut node) if func.generic_names.len > 0 { - if node.concrete_types.len == 0 { - node.concrete_types = c.infer_fn_value_concrete_types(func, c.expected_type) - } if node.concrete_types.len == 0 { c.error('`${node.name}` is a generic fn, you should pass its concrete types, e.g. ${node.name}[int]', node.pos) diff --git a/vlib/v/tests/generics/generic_function_argument_inference_test.v b/vlib/v/tests/generics/generic_function_argument_inference_test.v new file mode 100644 index 000000000..652ceffaf --- /dev/null +++ b/vlib/v/tests/generics/generic_function_argument_inference_test.v @@ -0,0 +1,26 @@ +import math + +struct Vec5 { + x f64 + y f64 + z f64 + a f64 + b f64 +} + +fn vec5_from(value f64) Vec5 { + return Vec5{value, value, value, value, value} +} + +fn (v Vec5) abs_oldstyle() Vec5 { + return Vec5{math.abs(v.x), math.abs(v.y), math.abs(v.z), math.abs(v.a), math.abs(v.b)} +} + +fn (v Vec5) generic_new(f fn (f64) f64) Vec5 { + return Vec5{f(v.x), f(v.y), f(v.z), f(v.a), f(v.b)} +} + +fn test_generic_function_argument_inference() { + v := vec5_from(-1.0) + assert v.generic_new(math.abs) == v.abs_oldstyle() +} -- 2.39.5