From ebf629dc4124a6207b18028cf89dc6910d25eaf3 Mon Sep 17 00:00:00 2001 From: CreeperFace <165158232+dy-tea@users.noreply.github.com> Date: Mon, 10 Nov 2025 07:18:17 +0000 Subject: [PATCH] ast,checker: improve type checking for sumtypes with generics (fix #25690) (#25699) --- vlib/v/ast/table.v | 86 +++++++++++++++---- vlib/v/ast/types.v | 1 + vlib/v/checker/check_types.v | 46 ++++++++++ vlib/v/checker/fn.v | 27 ++++++ vlib/v/checker/struct.v | 47 ++++++++-- vlib/v/parser/struct.v | 2 + .../sumtypes/sumtype_type_coercion_test.v | 20 +++++ 7 files changed, 208 insertions(+), 21 deletions(-) create mode 100644 vlib/v/tests/sumtypes/sumtype_type_coercion_test.v diff --git a/vlib/v/ast/table.v b/vlib/v/ast/table.v index 696ebd748..2ee7da778 100644 --- a/vlib/v/ast/table.v +++ b/vlib/v/ast/table.v @@ -937,6 +937,9 @@ pub fn (mut t Table) register_sym(sym TypeSymbol) int { ...sym } t.type_symbols[idx].idx = idx + if t.type_symbols[idx].ngname == '' { + t.type_symbols[idx].ngname = strip_generic_params(sym.name) + } t.type_idxs[sym_name] = idx return idx } @@ -960,6 +963,11 @@ pub fn (t &Table) known_type(name string) bool { return t.type_idxs[name] != 0 || t.parsing_type == name || name in ['i32', 'byte'] } +@[inline] +pub fn strip_generic_params(name string) string { + return name.all_before('[') +} + // start_parsing_type open the scope during the parsing of a type // where the type name must include the module prefix pub fn (mut t Table) start_parsing_type(type_name string) { @@ -1178,6 +1186,7 @@ pub fn (mut t Table) find_or_register_chan(elem_type Type, is_mut bool) int { kind: .chan name: name cname: cname + ngname: strip_generic_params(name) info: Chan{ elem_type: elem_type is_mut: is_mut @@ -1199,6 +1208,7 @@ pub fn (mut t Table) find_or_register_map(key_type Type, value_type Type) int { kind: .map name: name cname: cname + ngname: strip_generic_params(name) info: Map{ key_type: key_type value_type: value_type @@ -1220,6 +1230,7 @@ pub fn (mut t Table) find_or_register_thread(return_type Type) int { kind: .thread name: name cname: cname + ngname: strip_generic_params(name) info: Thread{ return_type: return_type } @@ -1242,6 +1253,7 @@ pub fn (mut t Table) find_or_register_promise(return_type Type) int { kind: .struct name: name cname: cname + ngname: strip_generic_params(name) info: Struct{ concrete_types: [return_type, t.type_idxs['JS.Any']] } @@ -1265,6 +1277,7 @@ pub fn (mut t Table) find_or_register_array(elem_type Type) int { kind: .array name: name cname: cname + ngname: strip_generic_params(name) info: Array{ nr_dims: 1 elem_type: elem_type @@ -1292,10 +1305,11 @@ pub fn (mut t Table) find_or_register_array_fixed(elem_type Type, size int, size cname := prefix + t.array_fixed_cname(elem_type, size) // register array_fixed_type := TypeSymbol{ - kind: .array_fixed - name: name - cname: cname - info: ArrayFixed{ + kind: .array_fixed + name: name + cname: cname + ngname: strip_generic_params(name) + info: ArrayFixed{ elem_type: elem_type size: size size_expr: size_expr @@ -1328,10 +1342,11 @@ pub fn (mut t Table) find_or_register_multi_return(mr_typs []Type) int { return existing_idx } multireg_sym := TypeSymbol{ - kind: .multi_return - name: name - cname: cname - info: MultiReturn{ + kind: .multi_return + name: name + cname: cname + ngname: strip_generic_params(name) + info: MultiReturn{ types: mr_typs } } @@ -1354,11 +1369,12 @@ pub fn (mut t Table) find_or_register_fn_type(f Fn, is_anon bool, has_decl bool) return existing_idx } return t.register_sym( - kind: .function - name: name - cname: cname - mod: f.mod - info: FnType{ + kind: .function + name: name + cname: cname + ngname: strip_generic_params(name) + mod: f.mod + info: FnType{ is_anon: anon has_decl: has_decl func: f @@ -1366,6 +1382,44 @@ pub fn (mut t Table) find_or_register_fn_type(f Fn, is_anon bool, has_decl bool) ) } +pub fn (mut t Table) find_or_register_generic_inst(parent_typ Type, concrete_types []Type) int { + parent_sym := t.sym(parent_typ) + if parent_sym.info !is Struct { + return 0 + } + struct_info := parent_sym.info as Struct + if struct_info.generic_types.len == 0 || concrete_types.len != struct_info.generic_types.len { + return 0 + } + mut inst_name := parent_sym.ngname + '[' + mut inst_cname := parent_sym.cname + '_T_' + for i, ct in concrete_types { + ct_sym := t.sym(ct) + inst_name += ct_sym.name + inst_cname += ct_sym.cname + if i < concrete_types.len - 1 { + inst_name += ', ' + inst_cname += '_T_' + } + } + inst_name += ']' + existing_idx := t.type_idxs[inst_name] + if existing_idx > 0 { + return existing_idx + } + return t.register_sym( + kind: .generic_inst + name: inst_name + cname: inst_cname + ngname: parent_sym.ngname + mod: parent_sym.mod + info: GenericInst{ + parent_idx: parent_typ.idx() + concrete_types: concrete_types + } + ) +} + pub fn (mut t Table) add_placeholder_type(name string, cname string, language Language) int { mut modname := '' if name.contains('.') { @@ -1375,6 +1429,7 @@ pub fn (mut t Table) add_placeholder_type(name string, cname string, language La kind: .placeholder name: name cname: util.no_dots(cname).replace_each(['&', '']) + ngname: strip_generic_params(name) language: language mod: modname is_pub: true @@ -1893,7 +1948,7 @@ pub fn (mut t Table) convert_generic_type(generic_type Type, generic_names []str if sym.info.is_generic { mut nrt := '${sym.name}[' mut rnrt := '${sym.rname}[' - mut cnrt := '${sym.cname}[' + mut cnrt := '${sym.cname}_T_' mut t_generic_names := generic_names.clone() mut t_to_types := to_types.clone() if sym.generic_types.len > 0 && sym.generic_types.len == sym.info.generic_types.len @@ -1930,7 +1985,7 @@ pub fn (mut t Table) convert_generic_type(generic_type Type, generic_names []str if i != sym.info.generic_types.len - 1 { nrt += ', ' rnrt += ', ' - cnrt += ', ' + cnrt += '_' } } else { return none @@ -1938,7 +1993,6 @@ pub fn (mut t Table) convert_generic_type(generic_type Type, generic_names []str } nrt += ']' rnrt += ']' - cnrt += ']' mut idx := t.type_idxs[nrt] if idx == 0 { idx = t.type_idxs[rnrt] diff --git a/vlib/v/ast/types.v b/vlib/v/ast/types.v index a575906c0..c7928c8e5 100644 --- a/vlib/v/ast/types.v +++ b/vlib/v/ast/types.v @@ -115,6 +115,7 @@ pub mut: name string // the internal & source name of the type, i.e. `[5]int`. cname string // the name with no dots for use in the generated C code rname string // the raw name + ngname string // the name without generic parameters methods []Fn generic_types []Type mod string diff --git a/vlib/v/checker/check_types.v b/vlib/v/checker/check_types.v index 72959277c..9dccd0e2e 100644 --- a/vlib/v/checker/check_types.v +++ b/vlib/v/checker/check_types.v @@ -460,6 +460,21 @@ fn (mut c Checker) check_basic(got ast.Type, expected ast.Type) bool { if c.table.sumtype_has_variant(expected, ast.mktyp(got), false) { return true } + if exp_sym.kind == .placeholder && c.expected_type != ast.void_type { + base_type := c.table.find_type(exp_sym.ngname) + if base_type != 0 { + base_sym := c.table.sym(base_type) + if base_sym.kind == .sum_type && base_sym.info is ast.SumType { + base_info := base_sym.info as ast.SumType + for variant in base_info.variants { + variant_sym := c.table.sym(variant) + if variant_sym.ngname == got_sym.ngname { + return true + } + } + } + } + } // struct if exp_sym.kind == .struct && got_sym.kind == .struct { if c.table.type_to_str(expected) == c.table.type_to_str(got) { @@ -950,6 +965,37 @@ fn (mut c Checker) infer_struct_generic_types(typ ast.Type, node ast.StructInit) } } } + } else if field_sym.info is ast.SumType { + for t in node.init_fields { + if ft.name == t.name && t.typ != 0 { + init_sym := c.table.sym(t.typ) + for variant in field_sym.info.variants { + variant_sym := c.table.sym(variant) + if variant_sym.name == init_sym.name { + if variant_sym.info is ast.Struct + && variant_sym.info.generic_types.len > 0 { + if init_sym.info is ast.Struct + && init_sym.info.concrete_types.len > 0 { + concrete_types << ast.mktyp(init_sym.info.concrete_types[0]) + continue gname + } + } else { + for init_field in node.init_fields { + if init_field.name != t.name && init_field.typ != 0 { + field := sym.info.fields.filter(it.name == init_field.name) + if field.len > 0 { + if c.table.sym(field[0].typ).name == gt_name { + concrete_types << ast.mktyp(init_field.typ) + continue gname + } + } + } + } + } + } + } + } + } } } c.error('could not infer generic type `${gt_name}` in generic struct `${sym.name}[${generic_names.join(', ')}]`', diff --git a/vlib/v/checker/fn.v b/vlib/v/checker/fn.v index e9385ffa6..fa0ad9135 100644 --- a/vlib/v/checker/fn.v +++ b/vlib/v/checker/fn.v @@ -1569,6 +1569,33 @@ fn (mut c Checker) fn_call(mut node ast.CallExpr, mut continue_check &bool) ast. node.args[i].typ = call_arg.expr.obj.typ } } + // sumtype coercion + param_type_sym := c.table.sym(param.typ) + if param_type_sym.kind == .placeholder { + base_type := c.table.find_type(param_type_sym.ngname) + if base_type != 0 { + base_sym := c.table.sym(base_type) + if base_sym.kind == .sum_type && base_sym.info is ast.SumType { + base_info := base_sym.info as ast.SumType + arg_typ_sym := c.table.sym(arg_typ) + for variant in base_info.variants { + variant_sym := c.table.sym(variant) + variant_base_name := variant_sym.ngname + if variant_base_name == arg_typ_sym.ngname { + node.args[i].expr = ast.CastExpr{ + expr: call_arg.expr + typ: param.typ + typname: c.table.type_to_str(param.typ) + pos: call_arg.expr.pos() + } + node.args[i].typ = param.typ + arg_typ = param.typ + break + } + } + } + } + } arg_typ_sym := c.table.sym(arg_typ) if param.typ.has_flag(.generic) { if arg_typ_sym.kind == .none && !param.typ.has_flag(.option) { diff --git a/vlib/v/checker/struct.v b/vlib/v/checker/struct.v index a34899ca5..55a096d75 100644 --- a/vlib/v/checker/struct.v +++ b/vlib/v/checker/struct.v @@ -700,6 +700,14 @@ fn (mut c Checker) struct_init(mut node ast.StructInit, is_field_zero_struct_ini mut exp_type := ast.no_type inited_fields << field_name exp_type = field_info.typ + if c.inside_generic_struct_init && exp_type.has_flag(.generic) { + generic_names := c.cur_struct_generic_types.map(c.table.sym(it).name) + if unwrapped := c.table.convert_generic_type(exp_type, generic_names, + c.cur_struct_concrete_types) + { + exp_type = unwrapped + } + } exp_type_sym := c.table.sym(exp_type) c.expected_type = exp_type got_type = c.expr(mut init_field.expr) @@ -783,9 +791,35 @@ or use an explicit `unsafe{ a[..] }`, if you do not want a copy of the slice.', } } else if got_type != ast.void_type && got_type_sym.kind != .placeholder && !exp_type.has_flag(.generic) { - c.check_expected(c.unwrap_generic(got_type), c.unwrap_generic(exp_type)) or { - c.error('cannot assign to field `${field_info.name}`: ${err.msg()}', - init_field.pos) + mut needs_sum_type_cast := false + if exp_type_sym.kind == .placeholder { + base_type := c.table.find_type(exp_type_sym.ngname) + if base_type != 0 { + base_sym := c.table.sym(base_type) + if base_sym.kind == .sum_type && base_sym.info is ast.SumType { + base_info := base_sym.info as ast.SumType + for variant in base_info.variants { + if c.table.sym(variant).ngname == got_type_sym.ngname { + needs_sum_type_cast = true + break + } + } + } + } + } + if needs_sum_type_cast { + init_field.expr = ast.CastExpr{ + expr: init_field.expr + typ: exp_type + typname: c.table.type_to_str(exp_type) + pos: init_field.expr.pos() + } + init_field.typ = exp_type + } else { + c.check_expected(c.unwrap_generic(got_type), c.unwrap_generic(exp_type)) or { + c.error('cannot assign to field `${field_info.name}`: ${err.msg()}', + init_field.pos) + } } } if exp_type.has_flag(.shared_f) { @@ -936,8 +970,11 @@ or use an explicit `unsafe{ a[..] }`, if you do not want a copy of the slice.', if struct_sym.info.concrete_types.len == 0 { concrete_types := c.infer_struct_generic_types(node.typ, node) if concrete_types.len > 0 { - generic_names := struct_sym.info.generic_types.map(c.table.sym(it).name) - node.typ = c.table.unwrap_generic_type(node.typ, generic_names, concrete_types) + idx := c.table.find_or_register_generic_inst(node.typ, concrete_types) + if idx > 0 { + node.typ = ast.new_type(idx) + c.table.generic_insts_to_concrete() + } } } else if struct_sym.info.generic_types.len == struct_sym.info.concrete_types.len { parent_type := struct_sym.info.parent_type diff --git a/vlib/v/parser/struct.v b/vlib/v/parser/struct.v index 29e856f60..d2063528e 100644 --- a/vlib/v/parser/struct.v +++ b/vlib/v/parser/struct.v @@ -433,6 +433,7 @@ fn (mut p Parser) struct_decl(is_anon bool) ast.StructDecl { language: language name: name cname: util.no_dots(name) + ngname: ast.strip_generic_params(name) mod: p.mod info: ast.Struct{ scoped_name: scoped_name @@ -727,6 +728,7 @@ fn (mut p Parser) interface_decl() ast.InterfaceDecl { kind: .interface name: interface_name cname: util.no_dots(interface_name) + ngname: ast.strip_generic_params(interface_name) mod: p.mod info: ast.Interface{ types: [] diff --git a/vlib/v/tests/sumtypes/sumtype_type_coercion_test.v b/vlib/v/tests/sumtypes/sumtype_type_coercion_test.v new file mode 100644 index 000000000..ea95d4549 --- /dev/null +++ b/vlib/v/tests/sumtypes/sumtype_type_coercion_test.v @@ -0,0 +1,20 @@ +struct Empty {} + +struct Node[T] { + value T + next Chain[T] +} + +type Chain[T] = Empty | Node[T] + +fn get[T](chain Chain[T]) T { + return match chain { + Empty { 0 } + Node[T] { chain.value } + } +} + +fn test_main() { + chain := Node{0.2, Empty{}} + assert get(chain) == 0.2 +} -- 2.39.5