From f941bb4af0206fd5bd19c41759ef869c02380468 Mon Sep 17 00:00:00 2001 From: Felipe Pena Date: Mon, 4 Dec 2023 07:19:52 -0300 Subject: [PATCH] checker, cgen: fix comptimecall type resolution on function args (#20070) --- vlib/v/checker/checker.v | 3 + vlib/v/checker/comptime.v | 63 ++++++++++++++++++- vlib/v/checker/fn.v | 2 + .../comptime_field_selector_not_name_err.out | 14 +++++ vlib/v/gen/c/fn.v | 8 +++ vlib/v/tests/method_call_resolve_test.v | 45 +++++++++++++ 6 files changed, 133 insertions(+), 2 deletions(-) create mode 100644 vlib/v/tests/method_call_resolve_test.v diff --git a/vlib/v/checker/checker.v b/vlib/v/checker/checker.v index bfa27d199..5ff9824f1 100644 --- a/vlib/v/checker/checker.v +++ b/vlib/v/checker/checker.v @@ -113,6 +113,9 @@ mut: comptime_fields_type map[string]ast.Type comptime_for_field_value ast.StructField // value of the field variable comptime_enum_field_value string // current enum value name + comptime_for_method string // $for method in T.methods {} + comptime_for_method_var string // $for method in T.methods {}; the variable name + comptime_values_stack []CurrentComptimeValues // stores the values from the above on each $for loop, to make nesting them easier fn_scope &ast.Scope = unsafe { nil } main_fn_decl_node ast.FnDecl match_exhaustive_cutoff_limit int = 10 diff --git a/vlib/v/checker/comptime.v b/vlib/v/checker/comptime.v index 71d99c2c7..5498ae894 100644 --- a/vlib/v/checker/comptime.v +++ b/vlib/v/checker/comptime.v @@ -44,6 +44,15 @@ fn (mut c Checker) get_comptime_var_type(node ast.Expr) ast.Type { } else if node is ast.SelectorExpr && c.is_comptime_selector_type(node) { // field_var.typ from $for field return c.comptime_fields_default_type + } else if node is ast.ComptimeCall { + method_name := c.comptime_for_method + left_sym := c.table.sym(c.unwrap_generic(node.left_type)) + f := left_sym.find_method(method_name) or { + c.error('could not find method `${method_name}` on compile-time resolution', + node.method_pos) + return ast.void_type + } + return f.return_type } return ast.void_type } @@ -274,6 +283,7 @@ fn (mut c Checker) comptime_for(mut node ast.ComptimeFor) { return } } + c.push_existing_comptime_values() c.inside_comptime_for_field = true for field in fields { if c.field_data_type == 0 { @@ -298,8 +308,7 @@ fn (mut c Checker) comptime_for(mut node ast.ComptimeFor) { } } } - c.comptime_for_field_var = '' - c.inside_comptime_for_field = false + c.pop_existing_comptime_values() } else if c.table.generic_type_names(node.typ).len == 0 && sym.kind != .placeholder { c.error('iterating over .fields is supported only for structs and interfaces, and ${sym.name} is neither', node.typ_pos) @@ -307,6 +316,7 @@ fn (mut c Checker) comptime_for(mut node ast.ComptimeFor) { } } else if node.kind == .values { if sym.kind == .enum_ { + c.push_existing_comptime_values() sym_info := sym.info as ast.Enum c.inside_comptime_for_field = true if c.enum_data_type == 0 { @@ -319,11 +329,24 @@ fn (mut c Checker) comptime_for(mut node ast.ComptimeFor) { c.comptime_fields_type['${node.val_var}.typ'] = node.typ c.stmts(mut node.stmts) } + c.pop_existing_comptime_values() } else { c.error('iterating over .values is supported only for enums, and ${sym.name} is not an enum', node.typ_pos) return } + } else if node.kind == .methods { + mut methods := sym.methods.filter(it.attrs.len == 0) // methods without attrs first + methods_with_attrs := sym.methods.filter(it.attrs.len > 0) // methods with attrs second + methods << methods_with_attrs + + c.push_existing_comptime_values() + for method in methods { + c.comptime_for_method = method.name + c.comptime_for_method_var = node.val_var + c.stmts(mut node.stmts) + } + c.pop_existing_comptime_values() } else { c.stmts(mut node.stmts) } @@ -1036,3 +1059,39 @@ fn (mut c Checker) get_comptime_selector_bool_field(field_name string) bool { else { return false } } } + +struct CurrentComptimeValues { + inside_comptime_for_field bool + comptime_for_field_var string + comptime_fields_default_type ast.Type + comptime_fields_type map[string]ast.Type + comptime_for_field_value ast.StructField + comptime_enum_field_value string + comptime_for_method string + comptime_for_method_var string +} + +fn (mut c Checker) push_existing_comptime_values() { + c.comptime_values_stack << CurrentComptimeValues{ + inside_comptime_for_field: c.inside_comptime_for_field + comptime_for_field_var: c.comptime_for_field_var + comptime_fields_default_type: c.comptime_fields_default_type + comptime_fields_type: c.comptime_fields_type.clone() + comptime_for_field_value: c.comptime_for_field_value + comptime_enum_field_value: c.comptime_enum_field_value + comptime_for_method: c.comptime_for_method + comptime_for_method_var: c.comptime_for_method_var + } +} + +fn (mut c Checker) pop_existing_comptime_values() { + old := c.comptime_values_stack.pop() + c.inside_comptime_for_field = old.inside_comptime_for_field + c.comptime_for_field_var = old.comptime_for_field_var + c.comptime_fields_default_type = old.comptime_fields_default_type + c.comptime_fields_type = old.comptime_fields_type.clone() + c.comptime_for_field_value = old.comptime_for_field_value + c.comptime_enum_field_value = old.comptime_enum_field_value + c.comptime_for_method = old.comptime_for_method + c.comptime_for_method_var = old.comptime_for_method_var +} diff --git a/vlib/v/checker/fn.v b/vlib/v/checker/fn.v index 993fce032..30faf67dc 100644 --- a/vlib/v/checker/fn.v +++ b/vlib/v/checker/fn.v @@ -1592,6 +1592,8 @@ fn (mut c Checker) get_comptime_args(func ast.Fn, node_ ast.CallExpr, concrete_t } else if call_arg.expr is ast.ComptimeSelector && c.table.is_comptime_var(call_arg.expr) { comptime_args[i] = c.get_comptime_var_type(call_arg.expr) + } else if call_arg.expr is ast.ComptimeCall { + comptime_args[i] = c.get_comptime_var_type(call_arg.expr) } } } diff --git a/vlib/v/checker/tests/comptime_field_selector_not_name_err.out b/vlib/v/checker/tests/comptime_field_selector_not_name_err.out index 446122674..58b3560b4 100644 --- a/vlib/v/checker/tests/comptime_field_selector_not_name_err.out +++ b/vlib/v/checker/tests/comptime_field_selector_not_name_err.out @@ -40,3 +40,17 @@ vlib/v/checker/tests/comptime_field_selector_not_name_err.vv:15:12: error: expec | ~~~~ 16 | } 17 | +vlib/v/checker/tests/comptime_field_selector_not_name_err.vv:15:10: error: compile time field access can only be used when iterating over `T.fields` + 13 | } + 14 | } + 15 | _ = t.$(f.name) + | ^ + 16 | } + 17 | +vlib/v/checker/tests/comptime_field_selector_not_name_err.vv:15:10: error: unknown `$for` variable `f` + 13 | } + 14 | } + 15 | _ = t.$(f.name) + | ^ + 16 | } + 17 | diff --git a/vlib/v/gen/c/fn.v b/vlib/v/gen/c/fn.v index f38597fb1..15763ab75 100644 --- a/vlib/v/gen/c/fn.v +++ b/vlib/v/gen/c/fn.v @@ -1154,6 +1154,14 @@ fn (mut g Gen) change_comptime_args(func ast.Fn, mut node_ ast.CallExpr, concret if param_typ.nr_muls() > 0 && comptime_args[i].nr_muls() > 0 { comptime_args[i] = comptime_args[i].set_nr_muls(0) } + } else if mut call_arg.expr is ast.ComptimeCall { + if call_arg.expr.method_name == 'method' { + sym := g.table.sym(g.unwrap_generic(call_arg.expr.left_type)) + // `app.$method()` + if m := sym.find_method(g.comptime_for_method) { + comptime_args[i] = m.return_type + } + } } } } diff --git a/vlib/v/tests/method_call_resolve_test.v b/vlib/v/tests/method_call_resolve_test.v new file mode 100644 index 000000000..7bcf90f1d --- /dev/null +++ b/vlib/v/tests/method_call_resolve_test.v @@ -0,0 +1,45 @@ +struct Human { + name string +} + +enum Animal { + dog + cat +} + +type Entity = Animal | Human + +@[sumtype_to: Animal] +fn (ent Entity) json_cast_to_animal() Animal { + return ent as Animal +} + +@[sumtype_to: Human] +fn (ent Entity) json_cast_to_human() Human { + return ent as Human +} + +fn encode[T](val T) { + $if T is $sumtype { + $for method in T.methods { + if method.attrs.len >= 1 { + if method.attrs[0].contains('sumtype_to') { + if val.type_name() == method.attrs[0].all_after('sumtype_to:').trim_space() { + encode(val.$method()) + } + } + } + } + } $else $if T is $struct { + assert val == Human{ + name: 'Monke' + } + } $else $if T is $enum { + assert val == Animal.cat + } +} + +fn test_main() { + encode(Entity(Human{'Monke'})) + encode(Entity(Animal.cat)) +} -- 2.39.5