From e5100a8f7366043b9ad58ad145703ec81d3ae576 Mon Sep 17 00:00:00 2001 From: Swastik Baranwal Date: Tue, 12 Aug 2025 22:01:34 +0530 Subject: [PATCH] cgen: allow alias sumtype smartcasting (fix #25085) (#25096) --- vlib/v/checker/if.v | 5 +++-- vlib/v/checker/infix.v | 2 +- vlib/v/gen/c/cgen.v | 2 +- vlib/v/gen/c/infix.v | 2 +- .../tests/sumtypes/alias_sumtype_smartcast_test.v | 14 ++++++++++++++ 5 files changed, 20 insertions(+), 5 deletions(-) create mode 100644 vlib/v/tests/sumtypes/alias_sumtype_smartcast_test.v diff --git a/vlib/v/checker/if.v b/vlib/v/checker/if.v index 1c2be9e9c..3f48fe033 100644 --- a/vlib/v/checker/if.v +++ b/vlib/v/checker/if.v @@ -707,6 +707,7 @@ fn (mut c Checker) smartcast_if_conds(mut node ast.Expr, mut scope ast.Scope, co right_sym := c.table.sym(right_type) mut expr_type := c.unwrap_generic(node.left_type) left_sym := c.table.sym(expr_type) + left_final_sym := c.table.final_sym(expr_type) if left_sym.kind == .aggregate { expr_type = (left_sym.info as ast.Aggregate).sum_type } @@ -714,7 +715,7 @@ fn (mut c Checker) smartcast_if_conds(mut node ast.Expr, mut scope ast.Scope, co if right_sym.kind != .interface { c.type_implements(right_type, expr_type, node.pos) } - } else if !c.check_types(right_type, expr_type) && left_sym.kind != .sum_type { + } else if !c.check_types(right_type, expr_type) && left_final_sym.kind != .sum_type { expect_str := c.table.type_to_str(right_type) expr_str := c.table.type_to_str(expr_type) c.error('cannot use type `${expect_str}` as type `${expr_str}`', node.pos) @@ -740,7 +741,7 @@ fn (mut c Checker) smartcast_if_conds(mut node ast.Expr, mut scope ast.Scope, co node.left.pos) } } - if left_sym.kind in [.interface, .sum_type] { + if left_final_sym.kind in [.interface, .sum_type] { c.smartcast(mut node.left, node.left_type, right_type, mut scope, is_comptime, false) } diff --git a/vlib/v/checker/infix.v b/vlib/v/checker/infix.v index 6378677a5..1493278d7 100644 --- a/vlib/v/checker/infix.v +++ b/vlib/v/checker/infix.v @@ -736,7 +736,7 @@ fn (mut c Checker) infix_expr(mut node ast.InfixExpr) ast.Type { if typ != ast.none_type_idx { c.error('`${op}` can only be used to test for none in sql', node.pos) } - } else if left_sym.kind !in [.interface, .sum_type] + } else if left_final_sym.kind !in [.interface, .sum_type] && !c.comptime.is_comptime(node.left) { c.error('`${op}` can only be used with interfaces and sum types', node.pos) // can be used in sql too, but keep err simple diff --git a/vlib/v/gen/c/cgen.v b/vlib/v/gen/c/cgen.v index 9e8ee8c8b..f45150d7c 100644 --- a/vlib/v/gen/c/cgen.v +++ b/vlib/v/gen/c/cgen.v @@ -5341,7 +5341,7 @@ fn (mut g Gen) ident(node ast.Ident) { } is_option = is_option || node.obj.orig_type.has_flag(.option) if node.obj.smartcasts.len > 0 { - obj_sym := g.table.sym(g.unwrap_generic(node.obj.typ)) + obj_sym := g.table.final_sym(g.unwrap_generic(node.obj.typ)) if !prevent_sum_type_unwrapping_once { nested_unwrap := node.obj.smartcasts.len > 1 unwrap_sumtype := is_option && nested_unwrap && obj_sym.kind == .sum_type diff --git a/vlib/v/gen/c/infix.v b/vlib/v/gen/c/infix.v index e82905e27..1173607fb 100644 --- a/vlib/v/gen/c/infix.v +++ b/vlib/v/gen/c/infix.v @@ -845,7 +845,7 @@ fn (mut g Gen) infix_expr_in_optimization(left ast.Expr, left_type ast.Type, rig // infix_expr_is_op generates code for `is` and `!is` fn (mut g Gen) infix_expr_is_op(node ast.InfixExpr) { - mut left_sym := g.table.sym(g.unwrap_generic(g.type_resolver.get_type_or_default(node.left, + mut left_sym := g.table.final_sym(g.unwrap_generic(g.type_resolver.get_type_or_default(node.left, node.left_type))) is_aggregate := node.left is ast.Ident && g.comptime.get_ct_type_var(node.left) == .aggregate right_sym := g.table.sym(node.right_type) diff --git a/vlib/v/tests/sumtypes/alias_sumtype_smartcast_test.v b/vlib/v/tests/sumtypes/alias_sumtype_smartcast_test.v new file mode 100644 index 000000000..70bade250 --- /dev/null +++ b/vlib/v/tests/sumtypes/alias_sumtype_smartcast_test.v @@ -0,0 +1,14 @@ +type Sum = int | string +type SumAlias = Sum + +fn test_alias_sumtype_smartcast() { + a_int := SumAlias(Sum(10)) + if a_int is int { + assert a_int == 10 + } + + a_str := SumAlias('foo') + if a_str is string { + assert a_str == 'foo' + } +} -- 2.39.5