From 72f6f681059c7e1916c60286f022d663af1d45d9 Mon Sep 17 00:00:00 2001 From: GGRei Date: Sun, 3 May 2026 16:34:31 +0200 Subject: [PATCH] cgen: resolve generic sumtype match method calls (#27068) --- vlib/v/gen/c/cgen.v | 4 + vlib/v/gen/c/fn.v | 37 ++++-- vlib/v/gen/c/match.v | 30 ++++- vlib/v/gen/c/utils.v | 37 ++++++ .../tests/generics/generic_match_expr_test.v | 111 ++++++++++++++++++ 5 files changed, 204 insertions(+), 15 deletions(-) diff --git a/vlib/v/gen/c/cgen.v b/vlib/v/gen/c/cgen.v index a3bbae66c..7cc592617 100644 --- a/vlib/v/gen/c/cgen.v +++ b/vlib/v/gen/c/cgen.v @@ -9005,6 +9005,10 @@ fn (mut g Gen) ident(node ast.Ident) { smartcast_types = [unwrapped_option_type] } } + match_variant_type := g.current_sumtype_match_variant_type(node, node.info.typ) + if match_variant_type != 0 { + smartcast_types = [match_variant_type] + } interface_source_is_interface := g.table.final_sym(g.unwrap_generic(resolved_var.typ)).kind == .interface || (resolved_var.orig_type != 0 diff --git a/vlib/v/gen/c/fn.v b/vlib/v/gen/c/fn.v index 6d7f984ec..c989b66f4 100644 --- a/vlib/v/gen/c/fn.v +++ b/vlib/v/gen/c/fn.v @@ -4229,10 +4229,16 @@ fn (mut g Gen) resolve_current_fn_generic_param_key_type(name string) ast.Type { fn (mut g Gen) unwrap_receiver_type(node ast.CallExpr) (ast.Type, &ast.TypeSymbol) { mut left_type := g.unwrap_generic(node.left_type) + mut match_variant_type := ast.Type(0) if node.left is ast.Ident { - resolved_left_type := g.resolve_current_fn_generic_param_type(node.left.name) - if resolved_left_type != 0 { - left_type = resolved_left_type + match_variant_type = g.current_sumtype_match_variant_type(node.left, node.left_type) + if match_variant_type != 0 { + left_type = match_variant_type + } else { + resolved_left_type := g.resolve_current_fn_generic_param_type(node.left.name) + if resolved_left_type != 0 { + left_type = resolved_left_type + } } } else if node.left is ast.StructInit { if g.cur_fn != unsafe { nil } && g.cur_concrete_types.len > 0 { @@ -4291,7 +4297,9 @@ fn (mut g Gen) unwrap_receiver_type(node ast.CallExpr) (ast.Type, &ast.TypeSymbo } } if node.from_embed_types.len == 0 && node.left is ast.Ident { - if node.left.obj is ast.Var { + if match_variant_type != 0 { + unwrapped_rec_type = match_variant_type + } else if node.left.obj is ast.Var { if node.left.obj.smartcasts.len > 0 { if node.left.obj.ct_type_var == .smartcast { unwrapped_rec_type = @@ -4403,14 +4411,19 @@ fn (mut g Gen) method_call(node ast.CallExpr) { mut receiver_type := node.receiver_type match node.left { ast.Ident { - resolved_left_type := g.resolve_current_fn_generic_param_type(node.left.name) - if resolved_left_type != 0 { - left_type = resolved_left_type - } else if g.cur_fn != unsafe { nil } && g.cur_concrete_types.len > 0 { - scope_type := g.resolved_scope_var_type(node.left) - if scope_type != 0 && !scope_type.has_flag(.generic) - && !g.type_has_unresolved_generic_parts(scope_type) { - left_type = scope_type + match_variant_type := g.current_sumtype_match_variant_type(node.left, node.left_type) + if match_variant_type != 0 { + left_type = match_variant_type + } else { + resolved_left_type := g.resolve_current_fn_generic_param_type(node.left.name) + if resolved_left_type != 0 { + left_type = resolved_left_type + } else if g.cur_fn != unsafe { nil } && g.cur_concrete_types.len > 0 { + scope_type := g.resolved_scope_var_type(node.left) + if scope_type != 0 && !scope_type.has_flag(.generic) + && !g.type_has_unresolved_generic_parts(scope_type) { + left_type = scope_type + } } } } diff --git a/vlib/v/gen/c/match.v b/vlib/v/gen/c/match.v index 3206c1c59..ceb4d27b2 100644 --- a/vlib/v/gen/c/match.v +++ b/vlib/v/gen/c/match.v @@ -193,6 +193,16 @@ fn (mut g Gen) match_expr_sumtype(node ast.MatchExpr, is_expr bool, cond_var str for { g.aggregate_type_idx = sumtype_index is_last := j == node.branches.len - 1 && sumtype_index == branch.exprs.len - 1 + mut has_branch_type := false + mut had_old_branch_type := false + mut old_branch_type := ast.Type(0) + mut branch_type := ast.Type(0) + if cond_sym.kind == .sum_type && sumtype_index < branch.exprs.len { + branch_expr := unsafe { &branch.exprs[sumtype_index] } + if branch_expr is ast.TypeNode { + branch_type = branch_expr.typ + } + } if branch.is_else || (use_ternary && is_last) { if use_ternary { // TODO: too many branches. maybe separate ?: matches @@ -230,9 +240,6 @@ fn (mut g Gen) match_expr_sumtype(node ast.MatchExpr, is_expr bool, cond_var str } cur_expr := unsafe { &branch.exprs[sumtype_index] } if cond_sym.kind == .sum_type { - if cur_expr is ast.TypeNode { - g.type_resolver.update_ct_type(cond_var, cur_expr.typ) - } g.write('${dot_or_ptr}_typ == ') if cur_expr is ast.None { g.write('${ast.none_type.idx()} /* none */') @@ -253,6 +260,15 @@ fn (mut g Gen) match_expr_sumtype(node ast.MatchExpr, is_expr bool, cond_var str g.writeln(') {') } } + if branch_type != 0 { + has_branch_type = true + if old_type := g.type_resolver.type_map[cond_var] { + had_old_branch_type = true + old_branch_type = old_type + } + g.type_resolver.update_ct_type(cond_var, branch_type) + g.clear_type_resolution_caches() + } if is_expr && tmp_var.len > 0 && g.table.sym(resolved_return_type).kind in [.sum_type, .interface] { g.expected_cast_type = resolved_return_type @@ -276,6 +292,14 @@ fn (mut g Gen) match_expr_sumtype(node ast.MatchExpr, is_expr bool, cond_var str g.write_defer_stmts(branch.scope, false, node.pos) g.inside_interface_deref = inside_interface_deref_old g.expected_cast_type = 0 + if has_branch_type { + if had_old_branch_type { + g.type_resolver.update_ct_type(cond_var, old_branch_type) + } else { + g.type_resolver.type_map.delete(cond_var) + } + g.clear_type_resolution_caches() + } if g.inside_ternary == 0 { g.writeln('}') g.set_current_pos_as_last_stmt_pos() diff --git a/vlib/v/gen/c/utils.v b/vlib/v/gen/c/utils.v index 1401dee00..bc1bb3d3c 100644 --- a/vlib/v/gen/c/utils.v +++ b/vlib/v/gen/c/utils.v @@ -20,6 +20,43 @@ fn (mut g Gen) clear_type_resolution_caches() { g.resolved_scope_var_type_cache.clear() } +fn (mut g Gen) current_sumtype_match_variant_type(ident ast.Ident, sumtype_type ast.Type) ast.Type { + if g.cur_fn == unsafe { nil } || g.cur_concrete_types.len == 0 { + return ast.Type(0) + } + mut branch_type := ast.Type(0) + if typ := g.type_resolver.type_map[ident.name] { + branch_type = typ + } + if branch_type == 0 { + cname := c_name(ident.name) + if cname != ident.name { + if typ := g.type_resolver.type_map[cname] { + branch_type = typ + } + } + } + if branch_type == 0 || branch_type == ast.void_type { + return ast.Type(0) + } + variant_type := g.unwrap_generic(g.recheck_concrete_type(branch_type)) + if variant_type == 0 || variant_type == ast.void_type || variant_type.has_flag(.generic) + || g.type_has_unresolved_generic_parts(variant_type) { + return ast.Type(0) + } + mut parent_type := g.unwrap_generic(g.recheck_concrete_type(sumtype_type)).set_nr_muls(0) + if parent_type != 0 && g.table.final_sym(parent_type).kind == .sum_type + && g.table.sumtype_has_variant(parent_type, variant_type, false) { + return variant_type + } + parent_type = g.resolve_current_fn_generic_param_type(ident.name).set_nr_muls(0) + if parent_type == 0 || g.table.final_sym(parent_type).kind != .sum_type + || !g.table.sumtype_has_variant(parent_type, variant_type, false) { + return ast.Type(0) + } + return variant_type +} + fn (g &Gen) type_resolution_context_key() u64 { mut key := cgen_resolution_hash_seed if g.inside_struct_init { diff --git a/vlib/v/tests/generics/generic_match_expr_test.v b/vlib/v/tests/generics/generic_match_expr_test.v index 303404489..9cec5897c 100644 --- a/vlib/v/tests/generics/generic_match_expr_test.v +++ b/vlib/v/tests/generics/generic_match_expr_test.v @@ -60,3 +60,114 @@ fn test_match_generic_sumtype_variant_in_generic_method() { mut x2 := GenericMatchSumtype[bool](z) assert x2.do_it() == 'doing GenericMatchZ' } + +type MatchedGenericMethodBox[T] = MatchedGenericMethodY[T] | MatchedGenericMethodZ[T] + +struct MatchedGenericMethodY[T] { + value T +} + +struct MatchedGenericMethodZ[T] { + value T +} + +fn (x MatchedGenericMethodBox[T]) explicit_variant_name[T]() string { + match x { + MatchedGenericMethodY[T] { return x.explicit_variant_name[T]() } + MatchedGenericMethodZ[T] { return x.explicit_variant_name[T]() } + } + + return '' +} + +fn (y MatchedGenericMethodY[T]) explicit_variant_name[T]() string { + return 'Y' +} + +fn (z MatchedGenericMethodZ[T]) explicit_variant_name[T]() string { + return 'Z' +} + +fn (x MatchedGenericMethodBox[T]) inferred_variant_name[T]() string { + match x { + MatchedGenericMethodY[T] { return x.inferred_variant_name() } + MatchedGenericMethodZ[T] { return x.inferred_variant_name() } + } + + return '' +} + +fn (x MatchedGenericMethodBox[T]) match_expr_variant_name[T]() string { + return match x { + MatchedGenericMethodY[T] { x.explicit_variant_name[T]() } + MatchedGenericMethodZ[T] { x.explicit_variant_name[T]() } + } +} + +fn (y MatchedGenericMethodY[T]) inferred_variant_name[T]() string { + return 'Y' +} + +fn (z MatchedGenericMethodZ[T]) inferred_variant_name[T]() string { + return 'Z' +} + +fn test_generic_sumtype_match_calls_explicit_generic_method_on_variant() { + y := MatchedGenericMethodBox[int](MatchedGenericMethodY[int]{ + value: 1 + }) + assert y.explicit_variant_name() == 'Y' + + z := MatchedGenericMethodBox[bool](MatchedGenericMethodZ[bool]{ + value: true + }) + assert z.explicit_variant_name() == 'Z' +} + +fn test_generic_sumtype_match_calls_explicit_generic_method_on_variant_reversed() { + z := MatchedGenericMethodBox[bool](MatchedGenericMethodZ[bool]{ + value: true + }) + assert z.explicit_variant_name() == 'Z' + + y := MatchedGenericMethodBox[int](MatchedGenericMethodY[int]{ + value: 1 + }) + assert y.explicit_variant_name() == 'Y' +} + +fn test_generic_sumtype_match_calls_explicit_generic_method_single_instantiation() { + y := MatchedGenericMethodBox[string](MatchedGenericMethodY[string]{ + value: 'one' + }) + assert y.explicit_variant_name() == 'Y' + + z := MatchedGenericMethodBox[string](MatchedGenericMethodZ[string]{ + value: 'two' + }) + assert z.explicit_variant_name() == 'Z' +} + +fn test_generic_sumtype_match_calls_inferred_generic_method_on_variant() { + z := MatchedGenericMethodBox[int](MatchedGenericMethodZ[int]{ + value: 1 + }) + assert z.inferred_variant_name() == 'Z' + + y := MatchedGenericMethodBox[bool](MatchedGenericMethodY[bool]{ + value: true + }) + assert y.inferred_variant_name() == 'Y' +} + +fn test_generic_sumtype_match_expr_calls_generic_method_on_last_variant() { + y := MatchedGenericMethodBox[int](MatchedGenericMethodY[int]{ + value: 1 + }) + assert y.match_expr_variant_name() == 'Y' + + z := MatchedGenericMethodBox[bool](MatchedGenericMethodZ[bool]{ + value: true + }) + assert z.match_expr_variant_name() == 'Z' +} -- 2.39.5