From 7df7f31741cef44547bd193ed5b04644f672ff21 Mon Sep 17 00:00:00 2001 From: Alexander Medvednikov Date: Wed, 25 Mar 2026 16:42:23 +0300 Subject: [PATCH] checker: fix cannot return struct implementing interface with return match (fixes #24148) --- vlib/v/checker/match.v | 132 ++++++------------ vlib/v/gen/c/match.v | 11 +- .../match_expr_returning_interface_test.v | 32 +++++ 3 files changed, 82 insertions(+), 93 deletions(-) create mode 100644 vlib/v/tests/conditions/matches/match_expr_returning_interface_test.v diff --git a/vlib/v/checker/match.v b/vlib/v/checker/match.v index 14cdb3bae..6fe8d1488 100644 --- a/vlib/v/checker/match.v +++ b/vlib/v/checker/match.v @@ -59,7 +59,7 @@ fn (mut c Checker) match_expr(mut node ast.MatchExpr) ast.Type { node.is_sum_type = cond_type_sym.kind in [.interface, .sum_type] c.match_exprs(mut node, cond_type_sym) c.expected_type = node.cond_type - mut ret_type_needs_inference := true + mut first_iteration := true mut infer_cast_type := ast.void_type mut need_explicit_cast := false mut ret_type := ast.void_type @@ -324,36 +324,32 @@ fn (mut c Checker) match_expr(mut node ast.MatchExpr) ast.Type { unwrapped_expected_type := c.unwrap_generic(node.expected_type) must_be_option = must_be_option || expr_type == ast.none_type stmt.typ = expr_type - is_noreturn := is_noreturn_callexpr(stmt.expr) - if ret_type_needs_inference { - if !is_noreturn { - if unwrapped_expected_type.has_option_or_result() - || c.table.type_kind(unwrapped_expected_type) in [.sum_type, .multi_return] { - c.check_match_branch_last_stmt(stmt, unwrapped_expected_type, - expr_type) - ret_type = node.expected_type - } else { - ret_type = expr_type - if expr_type.is_ptr() { - if stmt.expr is ast.Ident && stmt.expr.obj is ast.Var - && c.table.is_interface_var(stmt.expr.obj) { + if first_iteration { + if unwrapped_expected_type.has_option_or_result() + || c.table.type_kind(unwrapped_expected_type) in [.sum_type, .interface, .multi_return] { + c.check_match_branch_last_stmt(mut stmt, unwrapped_expected_type, + expr_type) + ret_type = node.expected_type + } else { + ret_type = expr_type + if expr_type.is_ptr() { + if stmt.expr is ast.Ident && stmt.expr.obj is ast.Var + && c.table.is_interface_var(stmt.expr.obj) { + ret_type = expr_type.deref() + } else if mut stmt.expr is ast.PrefixExpr + && stmt.expr.right is ast.Ident { + ident := stmt.expr.right as ast.Ident + if ident.obj is ast.Var && c.table.is_interface_var(ident.obj) { ret_type = expr_type.deref() - } else if mut stmt.expr is ast.PrefixExpr - && stmt.expr.right is ast.Ident { - ident := stmt.expr.right as ast.Ident - if ident.obj is ast.Var && c.table.is_interface_var(ident.obj) { - ret_type = expr_type.deref() - } } } - c.expected_expr_type = expr_type - } - infer_cast_type = stmt.typ - if mut stmt.expr is ast.CastExpr { - need_explicit_cast = true - infer_cast_type = stmt.expr.typ } - ret_type_needs_inference = false + c.expected_expr_type = expr_type + } + infer_cast_type = stmt.typ + if mut stmt.expr is ast.CastExpr { + need_explicit_cast = true + infer_cast_type = stmt.expr.typ } } else { if ret_type.idx() != expr_type.idx() { @@ -370,7 +366,7 @@ fn (mut c Checker) match_expr(mut node ast.MatchExpr) ast.Type { } stmt.typ = ast.error_type } else { - c.check_match_branch_last_stmt(stmt, c.unwrap_generic(ret_type), + c.check_match_branch_last_stmt(mut stmt, c.unwrap_generic(ret_type), expr_type) if ret_type.is_number() && expr_type.is_number() && !c.inside_return { ret_type = c.promote_num(ret_type, expr_type) @@ -380,7 +376,7 @@ fn (mut c Checker) match_expr(mut node ast.MatchExpr) ast.Type { if must_be_option && ret_type == ast.none_type && expr_type != ret_type { ret_type = expr_type.set_flag(.option) } - if stmt.typ != ast.error_type && !is_noreturn { + if stmt.typ != ast.error_type && !is_noreturn_callexpr(stmt.expr) { ret_sym := c.table.sym(ret_type) stmt_sym := c.table.sym(stmt.typ) if ret_sym.kind !in [.sum_type, .interface] @@ -492,6 +488,9 @@ fn (mut c Checker) match_expr(mut node ast.MatchExpr) ast.Type { ret_type = if stmt.types.len > 0 { stmt.types[0] } else { c.expected_type } } } + if !node.is_comptime || (node.is_comptime && comptime_match_branch_result) { + first_iteration = false + } if node.is_comptime { // branches may not have been processed by c.stmts() if has_top_return(branch.stmts) { @@ -545,13 +544,22 @@ fn (mut c Checker) match_expr(mut node ast.MatchExpr) ast.Type { return node.return_type } -fn (mut c Checker) check_match_branch_last_stmt(last_stmt ast.ExprStmt, ret_type ast.Type, expr_type ast.Type) { +fn (mut c Checker) check_match_branch_last_stmt(mut last_stmt ast.ExprStmt, ret_type ast.Type, expr_type ast.Type) { if !c.check_types(ret_type, expr_type) && !c.check_types(expr_type, ret_type) { ret_sym := c.table.sym(ret_type) + expr_sym := c.table.sym(expr_type) + if ret_sym.kind == .interface { + if c.type_implements(expr_type, ret_type, last_stmt.pos) { + if !expr_type.is_any_kind_of_pointer() && expr_sym.kind != .interface + && !c.inside_unsafe { + c.mark_as_referenced(mut &last_stmt.expr, true) + } + } + return + } is_noreturn := is_noreturn_callexpr(last_stmt.expr) if !(ret_sym.kind == .sum_type && (ret_type.has_flag(.generic) || c.table.is_sumtype_or_in_variant(ret_type, expr_type))) && !is_noreturn { - expr_sym := c.table.sym(expr_type) if expr_sym.kind == .multi_return && ret_sym.kind == .multi_return { ret_types := ret_sym.mr_info().types expr_types := expr_sym.mr_info().types.map(ast.mktyp(it)) @@ -570,51 +578,15 @@ fn (mut c Checker) check_match_branch_last_stmt(last_stmt ast.ExprStmt, ret_type } } -fn char_literal_number_value(value string) ?i64 { - if value.len == 2 && value[0] == `\\` { - return match value[1] { - `a` { 7 } - `b` { 8 } - `t` { 9 } - `n` { 10 } - `v` { 11 } - `f` { 12 } - `r` { 13 } - `e` { 27 } - `$` { 36 } - `"` { 34 } - `'` { 39 } - `?` { 63 } - `@` { 64 } - `\\` { 92 } - `\`` { 96 } - `{` { 123 } - `}` { 125 } - else { none } - } - } - runes := value.runes() - if runes.len == 1 { - return runes[0] - } - return none -} - fn (mut c Checker) get_comptime_number_value(mut expr ast.Expr) ?i64 { - if mut expr is ast.ParExpr { - return c.get_comptime_number_value(mut expr.expr) - } - if mut expr is ast.PrefixExpr && expr.op == .minus { - return -c.get_comptime_number_value(mut expr.right)? - } if mut expr is ast.CharLiteral { - return char_literal_number_value(expr.val) + return expr.val[0] } if mut expr is ast.IntegerLiteral { return expr.val.i64() } - if mut expr is ast.CastExpr { - return c.get_comptime_number_value(mut expr.expr) + if mut expr is ast.CastExpr && expr.expr is ast.IntegerLiteral { + return expr.expr.val.i64() } if mut expr is ast.Ident { if mut obj := c.table.global_scope.find_const(expr.full_name()) { @@ -627,24 +599,6 @@ fn (mut c Checker) get_comptime_number_value(mut expr ast.Expr) ?i64 { return none } -fn (mut c Checker) get_match_case_int_key(mut expr ast.Expr, cond_sym ast.TypeSymbol) ?string { - if !cond_sym.is_int() { - return none - } - if value := c.get_comptime_number_value(mut expr) { - return value.str() - } - if value := c.eval_comptime_const_expr(expr, 0) { - if signed_value := value.i64() { - return signed_value.str() - } - if unsigned_value := value.u64() { - return unsigned_value.str() - } - } - return none -} - fn (mut c Checker) match_exprs(mut node ast.MatchExpr, cond_type_sym ast.TypeSymbol) { c.expected_type = node.expected_type if node.cond_type.idx() == 0 { @@ -775,7 +729,7 @@ fn (mut c Checker) match_exprs(mut node ast.MatchExpr, cond_type_sym ast.TypeSym } } else { - key = c.get_match_case_int_key(mut expr, cond_sym) or { (*expr).str() } + key = (*expr).str() } } val := if key in branch_exprs { branch_exprs[key] } else { 0 } diff --git a/vlib/v/gen/c/match.v b/vlib/v/gen/c/match.v index a3f4345da..8aab4d1fb 100644 --- a/vlib/v/gen/c/match.v +++ b/vlib/v/gen/c/match.v @@ -18,7 +18,7 @@ fn (mut g Gen) need_tmp_var_in_match(node ast.MatchExpr) bool { if g.inside_struct_init { return true } - if g.table.sym(node.return_type).kind in [.sum_type, .multi_return] + if g.table.sym(node.return_type).kind in [.sum_type, .interface, .multi_return] || node.return_type.has_option_or_result() { return true } @@ -251,7 +251,8 @@ fn (mut g Gen) match_expr_sumtype(node ast.MatchExpr, is_expr bool, cond_var str g.writeln(') {') } } - if is_expr && tmp_var.len > 0 && g.table.sym(node.return_type).kind == .sum_type { + if is_expr && tmp_var.len > 0 + && g.table.sym(node.return_type).kind in [.sum_type, .interface] { g.expected_cast_type = node.return_type } inside_interface_deref_old := g.inside_interface_deref @@ -360,7 +361,8 @@ fn (mut g Gen) match_expr_switch(node ast.MatchExpr, is_expr bool, cond_var stri } } g.writeln('{') - if is_expr && tmp_var.len > 0 && g.table.sym(node.return_type).kind == .sum_type { + if is_expr && tmp_var.len > 0 + && g.table.sym(node.return_type).kind in [.sum_type, .interface] { g.expected_cast_type = node.return_type } ends_with_return := g.stmts_with_tmp_var(branch.stmts, tmp_var) @@ -569,7 +571,8 @@ fn (mut g Gen) match_expr_classic(node ast.MatchExpr, is_expr bool, cond_var str g.writeln(') {') } } - if is_expr && tmp_var.len > 0 && g.table.sym(node.return_type).kind == .sum_type { + if is_expr && tmp_var.len > 0 + && g.table.sym(node.return_type).kind in [.sum_type, .interface] { g.expected_cast_type = node.return_type } g.stmts_with_tmp_var(branch.stmts, tmp_var) diff --git a/vlib/v/tests/conditions/matches/match_expr_returning_interface_test.v b/vlib/v/tests/conditions/matches/match_expr_returning_interface_test.v new file mode 100644 index 000000000..7c5ccb115 --- /dev/null +++ b/vlib/v/tests/conditions/matches/match_expr_returning_interface_test.v @@ -0,0 +1,32 @@ +interface MatchExprInterface { + label() string +} + +struct MatchExprOne {} + +fn (m MatchExprOne) label() string { + return 'one' +} + +struct MatchExprTwo {} + +fn (m MatchExprTwo) label() string { + return 'two' +} + +enum MatchExprVariant { + one + two +} + +fn new_match_expr_interface(variant MatchExprVariant) MatchExprInterface { + return match variant { + .one { MatchExprOne{} } + .two { MatchExprTwo{} } + } +} + +fn test_return_match_expr_as_interface() { + assert new_match_expr_interface(.one).label() == 'one' + assert new_match_expr_interface(.two).label() == 'two' +} -- 2.39.5