From 6d3fc3d1461ce249c90765d44e6baab80af6d1f9 Mon Sep 17 00:00:00 2001 From: Alexander Medvednikov Date: Tue, 21 Apr 2026 16:33:07 +0300 Subject: [PATCH] checker: fix stacking multiple types producing error in match smartcast (fixes #25418) --- vlib/v/ast/table.v | 104 +++++++++++++++++- vlib/v/checker/fn.v | 15 +++ vlib/v/checker/match.v | 22 ++-- vlib/v/gen/c/cgen.v | 33 +++--- .../conditions/matches/match_smartcast_test.v | 44 ++++++++ 5 files changed, 195 insertions(+), 23 deletions(-) diff --git a/vlib/v/ast/table.v b/vlib/v/ast/table.v index 83ba63308..9bb330246 100644 --- a/vlib/v/ast/table.v +++ b/vlib/v/ast/table.v @@ -1936,6 +1936,107 @@ pub fn (t &Table) sumtype_has_variant(parent Type, variant Type, is_as bool) boo return false } +pub fn (t &Table) sumtype_has_variant_recursive(parent Type, variant Type, is_as bool) bool { + if t.sumtype_has_variant(parent, variant, is_as) { + return true + } + parent_sym := t.sym(parent) + if parent_sym.kind != .sum_type || parent_sym.info !is SumType { + return false + } + parent_info := parent_sym.info as SumType + for parent_variant in parent_info.variants { + if nested_sumtype := t.sumtype_nested_variant_type(parent_variant) { + if t.sumtype_has_variant_recursive(nested_sumtype, variant, is_as) { + return true + } + } + } + return false +} + +pub fn (t &Table) sumtype_matchable_variants(parent Type) []Type { + mut variants := []Type{} + mut seen := map[u32]bool{} + t.collect_sumtype_matchable_variants(parent, mut seen, mut variants) + return variants +} + +pub fn (t &Table) sumtype_missing_variants(parent Type, handled []Type) []Type { + mut missing := []Type{} + mut seen := map[u32]bool{} + t.collect_sumtype_missing_variants(parent, handled, mut seen, mut missing) + return missing +} + +fn (t &Table) collect_sumtype_matchable_variants(parent Type, mut seen map[u32]bool, mut variants []Type) { + parent_sym := t.sym(parent) + if parent_sym.kind != .sum_type || parent_sym.info !is SumType { + return + } + parent_info := parent_sym.info as SumType + for variant in parent_info.variants { + if u32(variant) !in seen { + seen[u32(variant)] = true + variants << variant + } + if nested_sumtype := t.sumtype_nested_variant_type(variant) { + t.collect_sumtype_matchable_variants(nested_sumtype, mut seen, mut variants) + } + } +} + +fn (t &Table) collect_sumtype_missing_variants(parent Type, handled []Type, mut seen map[u32]bool, mut missing []Type) { + if t.sumtype_variant_is_handled(parent, handled) { + return + } + if nested_sumtype := t.sumtype_nested_variant_type(parent) { + nested_sym := t.sym(nested_sumtype) + if nested_sym.kind == .sum_type && nested_sym.info is SumType { + nested_info := nested_sym.info as SumType + for variant in nested_info.variants { + if t.sumtype_variant_is_handled(variant, handled) { + continue + } + if nested_variant := t.sumtype_nested_variant_type(variant) { + t.collect_sumtype_missing_variants(nested_variant, handled, mut seen, mut + missing) + } else if u32(variant) !in seen { + seen[u32(variant)] = true + missing << variant + } + } + return + } + } + if u32(parent) !in seen { + seen[u32(parent)] = true + missing << parent + } +} + +fn (t &Table) sumtype_variant_is_handled(variant Type, handled []Type) bool { + for handled_variant in handled { + if t.same_sumtype_variant(variant, handled_variant, true) { + return true + } + } + return false +} + +fn (t &Table) same_sumtype_variant(expected Type, got Type, is_as bool) bool { + return expected.idx() == got.idx() && expected.has_flag(.option) == got.has_flag(.option) + && (!is_as || expected.nr_muls() == got.nr_muls()) +} + +fn (t &Table) sumtype_nested_variant_type(variant Type) ?Type { + nested_sumtype := t.fully_unaliased_type(variant) + if t.sym(nested_sumtype).kind == .sum_type { + return nested_sumtype + } + return none +} + fn (t &Table) sumtype_check_function_variant(parent_info SumType, variant Type, is_as bool) bool { variant_fn := (t.sym(variant).info as FnType).func variant_fn_sig := t.fn_type_source_signature(variant_fn) @@ -1954,8 +2055,7 @@ fn (t &Table) sumtype_check_function_variant(parent_info SumType, variant Type, fn (t &Table) sumtype_check_variant_in_type(parent_info SumType, variant Type, is_as bool) bool { for v in parent_info.variants { - if v.idx() == variant.idx() && variant.has_flag(.option) == v.has_flag(.option) - && (!is_as || v.nr_muls() == variant.nr_muls()) { + if t.same_sumtype_variant(v, variant, is_as) { return true } } diff --git a/vlib/v/checker/fn.v b/vlib/v/checker/fn.v index bf5e31bd1..2dba69b5e 100644 --- a/vlib/v/checker/fn.v +++ b/vlib/v/checker/fn.v @@ -2314,6 +2314,21 @@ fn (mut c Checker) fn_call(mut node ast.CallExpr, mut continue_check &bool) ast. } } } + if param_type_sym.kind == .sum_type + && !c.table.sumtype_has_variant(param.typ, arg_typ, false) + && c.table.sumtype_has_variant_recursive(param.typ, arg_typ, false) { + call_arg.expr = ast.CastExpr{ + expr: call_arg.expr + typ: param.typ + typname: c.table.type_to_str(param.typ) + expr_type: arg_typ + pos: call_arg.expr.pos() + } + call_arg.typ = param.typ + node.args[i].expr = call_arg.expr + node.args[i].typ = param.typ + arg_typ = param.typ + } arg_typ_sym := c.table.sym(arg_typ) if param.typ.has_flag(.generic) { if arg_typ_sym.kind == .none && !param.typ.has_flag(.option) { diff --git a/vlib/v/checker/match.v b/vlib/v/checker/match.v index 6749e2b13..9e76a8d28 100644 --- a/vlib/v/checker/match.v +++ b/vlib/v/checker/match.v @@ -694,6 +694,7 @@ fn (mut c Checker) match_exprs(mut node ast.MatchExpr, cond_type_sym ast.TypeSym // branch_exprs is a histogram of how many times // an expr was used in the match mut branch_exprs := map[string]int{} + mut branch_expr_types := []ast.Type{} is_multi_allowed_enum_match := cond_type_sym.info is ast.Enum && cond_type_sym.info.is_multi_allowed mut branch_enum_values := map[i64]bool{} @@ -899,11 +900,11 @@ fn (mut c Checker) match_exprs(mut node ast.MatchExpr, cond_type_sym ast.TypeSym } } } else if cond_match_sym.info is ast.SumType { - if expr_type !in cond_match_sym.info.variants { + if !c.table.sumtype_has_variant_recursive(cond_match_type, expr_type, true) { expr_str := c.table.type_to_str(expr_type) expect_str := c.table.type_to_str(node.cond_type) sumtype_variant_names := - cond_match_sym.info.variants.map(c.table.type_to_str_using_aliases(it, {})) + c.table.sumtype_matchable_variants(cond_match_type).map(c.table.type_to_str_using_aliases(it, {})) suggestion := util.new_suggestion(expr_str, sumtype_variant_names) c.error(suggestion.say('`${expect_str}` has no variant `${expr_str}`'), expr.pos()) @@ -918,6 +919,14 @@ fn (mut c Checker) match_exprs(mut node ast.MatchExpr, cond_type_sym ast.TypeSym expect_str := c.table.type_to_str(node.cond_type) c.error('cannot match `${expect_str}` with `${expr_str}`', expr.pos()) } + if is_type_node { + branch_expr_types << if is_alias_to_matchable_type + && expr_type == node.cond_type { + cond_match_type + } else { + expr_type + } + } } branch_exprs[key] = val + 1 } @@ -983,12 +992,9 @@ fn (mut c Checker) match_exprs(mut node ast.MatchExpr, cond_type_sym ast.TypeSym } else { match cond_match_sym.info { ast.SumType { - for v in cond_match_sym.info.variants { - v_str := c.table.type_to_str(v) - if v_str !in branch_exprs { - is_exhaustive = false - unhandled << '`${v_str}`' - } + for v in c.table.sumtype_missing_variants(cond_match_type, branch_expr_types) { + is_exhaustive = false + unhandled << '`${c.table.type_to_str(v)}`' } } // diff --git a/vlib/v/gen/c/cgen.v b/vlib/v/gen/c/cgen.v index ffdd92f48..ff4603101 100644 --- a/vlib/v/gen/c/cgen.v +++ b/vlib/v/gen/c/cgen.v @@ -1295,6 +1295,11 @@ pub fn (mut g Gen) get_sumtype_variant_name(typ ast.Type, sym ast.TypeSymbol) st return if typ.has_flag(.option) { '_option_${sym.cname}' } else { sym.cname } } +@[inline] +fn (g &Gen) sumtype_runtime_variants(typ ast.Type) []ast.Type { + return g.table.sumtype_matchable_variants(typ.idx_type()) +} + // get_sumtype_casting_variant_name returns a helper-safe variant name that // keeps pointer-depth distinctions for sumtype cast wrappers. @[inline] @@ -1326,9 +1331,10 @@ pub fn (mut g Gen) write_typeof_functions() { } g.writeln('${static_prefix}char * v_typeof_sumtype_${sym.cname}(u32 sidx) {') g.definitions.writeln('${static_prefix}char * v_typeof_sumtype_${sym.cname}(u32);') + runtime_variants := g.sumtype_runtime_variants(ast.idx_to_type(ityp)) if g.pref.build_mode == .build_module { g.writeln('\t\tif( sidx == _v_type_idx_${sym.cname}() ) return "${util.strip_main_name(sym.name)}";') - for v in sum_info.variants { + for v in runtime_variants { subtype := g.table.sym(v) g.writeln('\tif( sidx == _v_type_idx_${g.get_sumtype_variant_name(v, subtype)}() ) return "${util.strip_main_name(subtype.name)}";') } @@ -1338,7 +1344,7 @@ pub fn (mut g Gen) write_typeof_functions() { g.writeln('\tswitch(sidx) {') g.writeln('\t\tcase ${tidx}: return "${util.strip_main_name(sym.name)}";') mut idxs := []ast.Type{} - for v in sum_info.variants { + for v in runtime_variants { if v in idxs { continue } @@ -1353,7 +1359,7 @@ pub fn (mut g Gen) write_typeof_functions() { g.writeln('${static_prefix}u32 v_typeof_sumtype_idx_${sym.cname}(u32 sidx) {') if g.pref.build_mode == .build_module { g.writeln('\t\tif( sidx == _v_type_idx_${sym.cname}() ) return ${u32(ityp)};') - for v in sum_info.variants { + for v in runtime_variants { subtype := g.table.sym(v) g.writeln('\tif( sidx == _v_type_idx_${subtype.cname}() ) return ${u32(v)};') } @@ -1362,7 +1368,7 @@ pub fn (mut g Gen) write_typeof_functions() { tidx := g.table.find_type_idx(sym.name) g.writeln2('\tswitch(sidx) {', '\t\tcase ${tidx}: return ${u32(ityp)};') mut idxs := []ast.Type{} - for v in sum_info.variants { + for v in runtime_variants { if v in idxs { continue } @@ -3834,8 +3840,7 @@ fn (mut g Gen) write_sumtype_casting_fn(fun SumtypeCastingFn) { got_name := 'fn ${g.table.fn_type_source_signature(got_sym.info.func)}' got_cname = 'anon_fn_${g.table.fn_type_signature(got_sym.info.func)}' type_idx = g.table.type_idxs[got_name].str() - exp_info := exp_sym.info as ast.SumType - for variant in exp_info.variants { + for variant in g.sumtype_runtime_variants(exp) { variant_sym := g.table.sym(variant) if variant_sym.info is ast.FnType { if g.table.fn_type_source_signature(variant_sym.info.func) == g.table.fn_type_source_signature(got_sym.info.func) { @@ -4157,7 +4162,7 @@ fn (g &Gen) single_pointer_sumtype_nil_variant(expected_type ast.Type, expr ast. return 0 } mut variant := ast.Type(0) - for sumtype_variant in (expected_sym.info as ast.SumType).variants { + for sumtype_variant in g.sumtype_runtime_variants(expected_type) { if g.table.unaliased_type(sumtype_variant).is_any_kind_of_pointer() { if variant != 0 { return 0 @@ -4187,7 +4192,7 @@ fn (g &Gen) find_matching_sumtype_variant(expected_type ast.Type, got_type ast.T if expected_sym.kind != .sum_type { return got_type } - variants := (expected_sym.info as ast.SumType).variants + variants := g.sumtype_runtime_variants(expected_type) for variant in variants { if g.is_exact_sumtype_variant_match(variant, got_type) { return variant @@ -4608,7 +4613,8 @@ fn (mut g Gen) expr_with_cast(expr ast.Expr, got_type_raw ast.Type, expected_typ sumtype_got_type } if nil_sumtype_variant != 0 - || g.table.sumtype_has_variant(expected_deref_type, got_deref_type, false) { + || g.table.sumtype_has_variant(expected_deref_type, got_deref_type, false) + || g.table.sumtype_has_variant_recursive(expected_deref_type, got_deref_type, false) { mut is_already_sum_type := false scope := g.file.scope.innermost(expr.pos().pos) if expr is ast.Ident { @@ -10151,11 +10157,12 @@ fn (mut g Gen) write_types(symbols []&ast.TypeSymbol) { struct_names[name] = true g.typedefs.writeln('typedef struct ${name} ${name};') mut idxs := []ast.Type{} + runtime_variants := g.sumtype_runtime_variants(ast.idx_to_type(sym.idx)) if !g.pref.is_prod { // Do not print union sum type coment in prod mode g.type_definitions.writeln('') g.type_definitions.writeln('// Union sum type ${name} = ') - for variant in sym.info.variants { + for variant in runtime_variants { if variant in idxs { continue } @@ -10166,7 +10173,7 @@ fn (mut g Gen) write_types(symbols []&ast.TypeSymbol) { } g.type_definitions.writeln('struct ${name} {') g.type_definitions.writeln('\tunion {') - for variant in sym.info.variants { + for variant in runtime_variants { if variant in idxs { continue } @@ -10372,7 +10379,7 @@ fn (mut g Gen) sort_structs(typesa []&ast.TypeSymbol) []&ast.TypeSymbol { } } ast.SumType { - for variant in sym.info.variants { + for variant in g.sumtype_runtime_variants(ast.idx_to_type(sym.idx)) { vsym := g.table.sym(variant) if vsym.info !is ast.Struct { continue @@ -11277,7 +11284,7 @@ fn (mut g Gen) as_cast(node ast.AsCast) { } // fill as cast name table - for variant in expr_type_sym.info.variants { + for variant in g.sumtype_runtime_variants(node.expr_type) { idx := u32(variant).str() if idx in g.as_cast_type_names { continue diff --git a/vlib/v/tests/conditions/matches/match_smartcast_test.v b/vlib/v/tests/conditions/matches/match_smartcast_test.v index 5e4ecdd71..2b2e97252 100644 --- a/vlib/v/tests/conditions/matches/match_smartcast_test.v +++ b/vlib/v/tests/conditions/matches/match_smartcast_test.v @@ -117,3 +117,47 @@ fn test_branches_return_struct_field() { } assert m['item2']! == Any_1(42) } + +struct NestedMatchA { + a int +} + +struct NestedMatchB { + b int +} + +struct NestedMatchC { + c int +} + +type NestedMatchD = NestedMatchA | NestedMatchB +type NestedMatchE = NestedMatchC | NestedMatchD + +fn describe_nested_match(e NestedMatchE) string { + return match e { + NestedMatchA { 'A:${e.a}' } + NestedMatchB { 'B:${e.b}' } + NestedMatchC { 'C:${e.c}' } + } +} + +fn pass_nested_match(e NestedMatchE) NestedMatchE { + return e +} + +fn test_nested_sumtype_leaf_match_and_argument() { + a := NestedMatchA{ + a: 1 + } + b := NestedMatchB{ + b: 2 + } + c := NestedMatchC{ + c: 3 + } + + assert describe_nested_match(a) == 'A:1' + assert describe_nested_match(b) == 'B:2' + assert describe_nested_match(c) == 'C:3' + assert describe_nested_match(pass_nested_match(a)) == 'A:1' +} -- 2.39.5