From 350747793bfdf74d3be70ea3df075706871b0099 Mon Sep 17 00:00:00 2001 From: kbkpbot Date: Mon, 11 Aug 2025 14:01:49 +0800 Subject: [PATCH] math.bits: add asm implementations for some 64 bit ops (#25020) --- vlib/math/bits/bits.amd64.v | 100 +++++++++++++++++++ vlib/math/bits/bits.arm64.v | 55 +++++++++++ vlib/math/bits/bits.c.v | 186 ++++++++++++++++++++++++++++++++++++ vlib/math/bits/bits.v | 131 +++++++++++++++++++++++-- vlib/math/bits/bits_test.v | 52 ++++++++++ 5 files changed, 517 insertions(+), 7 deletions(-) create mode 100644 vlib/math/bits/bits.amd64.v create mode 100644 vlib/math/bits/bits.arm64.v create mode 100644 vlib/math/bits/bits.c.v diff --git a/vlib/math/bits/bits.amd64.v b/vlib/math/bits/bits.amd64.v new file mode 100644 index 000000000..ba3d6b53a --- /dev/null +++ b/vlib/math/bits/bits.amd64.v @@ -0,0 +1,100 @@ +// Copyright (c) 2019-2024 Alexander Medvednikov. All rights reserved. +// Use of this source code is governed by an MIT license +// that can be found in the LICENSE file. +module bits + +fn C._umul128(x u64, y u64, result_hi &u64) u64 +fn C._addcarry_u64(carry_in u8, a u64, b u64, out &u64) u8 +fn C._udiv128(hi u64, lo u64, y u64, rem &u64) u64 + +// mul_64 returns the 128-bit product of x and y: (hi, lo) = x * y +// with the product bits' upper half returned in hi and the lower +// half returned in lo. +// +// This function's execution time does not depend on the inputs. +@[inline] +pub fn mul_64(x u64, y u64) (u64, u64) { + mut hi := u64(0) + mut lo := u64(0) + $if msvc { + lo = C._umul128(x, y, &hi) + return hi, lo + } $else $if amd64 { + asm amd64 { + mulq rdx + ; =a (lo) + =d (hi) + ; a (x) + d (y) + ; cc + } + return hi, lo + } + // cross compile + return mul_64_default(x, y) +} + +// mul_add_64 returns the 128-bit result of x * y + z: (hi, lo) = x * y + z +// with the result bits' upper half returned in hi and the lower +// half returned in lo. +@[inline] +pub fn mul_add_64(x u64, y u64, z u64) (u64, u64) { + mut hi := u64(0) + mut lo := u64(0) + $if msvc { + lo = C._umul128(x, y, &hi) + carry := C._addcarry_u64(0, lo, z, &lo) + hi += carry + return hi, lo + } $else $if amd64 { + asm amd64 { + mulq rdx + addq rax, z + adcq rdx, 0 + ; =a (lo) + =d (hi) + ; a (x) + d (y) + r (z) + ; cc + } + return hi, lo + } + // cross compile + return mul_add_64_default(x, y, z) +} + +// div_64 returns the quotient and remainder of (hi, lo) divided by y: +// quo = (hi, lo)/y, rem = (hi, lo)%y with the dividend bits' upper +// half in parameter hi and the lower half in parameter lo. +// div_64 panics for y == 0 (division by zero) or y <= hi (quotient overflow). +@[inline] +pub fn div_64(hi u64, lo u64, y1 u64) (u64, u64) { + mut y := y1 + if y == 0 { + panic(overflow_error) + } + if y <= hi { + panic(overflow_error) + } + + mut quo := u64(0) + mut rem := u64(0) + $if msvc { + quo = C._udiv128(hi, lo, y, &rem) + return quo, rem + } $else $if amd64 { + asm amd64 { + div y + ; =a (quo) + =d (rem) + ; d (hi) + a (lo) + r (y) + ; cc + } + return quo, rem + } + // cross compile + return div_64_default(hi, lo, y1) +} diff --git a/vlib/math/bits/bits.arm64.v b/vlib/math/bits/bits.arm64.v new file mode 100644 index 000000000..110831e2e --- /dev/null +++ b/vlib/math/bits/bits.arm64.v @@ -0,0 +1,55 @@ +// Copyright (c) 2019-2024 Alexander Medvednikov. All rights reserved. +// Use of this source code is governed by an MIT license +// that can be found in the LICENSE file. +module bits + +// mul_64 returns the 128-bit product of x and y: (hi, lo) = x * y +// with the product bits' upper half returned in hi and the lower +// half returned in lo. +// +// This function's execution time does not depend on the inputs. +@[inline] +pub fn mul_64(x u64, y u64) (u64, u64) { + mut hi := u64(0) + mut lo := u64(0) + $if arm64 && !tinyc { + asm arm64 { + mul lo, x, y + umulh hi, x, y + ; =&r (hi) + =&r (lo) + ; r (x) + r (y) + ; cc + } + return hi, lo + } + // cross compile + return mul_64_default(x, y) +} + +// mul_add_64 returns the 128-bit result of x * y + z: (hi, lo) = x * y + z +// with the result bits' upper half returned in hi and the lower +// half returned in lo. +@[inline] +pub fn mul_add_64(x u64, y u64, z u64) (u64, u64) { + mut hi := u64(0) + mut lo := u64(0) + $if arm64 && !tinyc { + asm arm64 { + mul lo, x, y + umulh hi, x, y + adds lo, lo, z + adc hi, hi, xzr + ; =&r (hi) + =&r (lo) + ; r (x) + r (y) + r (z) + ; cc + } + return hi, lo + } + // cross compile + return mul_add_64_default(x, y, z) +} diff --git a/vlib/math/bits/bits.c.v b/vlib/math/bits/bits.c.v new file mode 100644 index 000000000..b4b595a74 --- /dev/null +++ b/vlib/math/bits/bits.c.v @@ -0,0 +1,186 @@ +// Copyright (c) 2019-2024 Alexander Medvednikov. All rights reserved. +// Use of this source code is governed by an MIT license +// that can be found in the LICENSE file. +module bits + +fn C.__builtin_clz(x u32) int +fn C.__builtin_clzll(x u64) int +fn C.__lzcnt(x u32) int +fn C.__lzcnt64(x u64) int + +// --- LeadingZeros --- +// leading_zeros_8 returns the number of leading zero bits in x; the result is 8 for x == 0. +@[inline] +pub fn leading_zeros_8(x u8) int { + if x == 0 { + return 8 + } + $if msvc { + return C.__lzcnt(x) - 24 + } $else $if !tinyc { + return C.__builtin_clz(x) - 24 + } + return leading_zeros_8_default(x) +} + +// leading_zeros_16 returns the number of leading zero bits in x; the result is 16 for x == 0. +@[inline] +pub fn leading_zeros_16(x u16) int { + if x == 0 { + return 16 + } + $if msvc { + return C.__lzcnt(x) - 16 + } $else $if !tinyc { + return C.__builtin_clz(x) - 16 + } + return leading_zeros_16_default(x) +} + +// leading_zeros_32 returns the number of leading zero bits in x; the result is 32 for x == 0. +@[inline] +pub fn leading_zeros_32(x u32) int { + if x == 0 { + return 32 + } + $if msvc { + return C.__lzcnt(x) + } $else $if !tinyc { + return C.__builtin_clz(x) + } + return leading_zeros_32_default(x) +} + +// leading_zeros_64 returns the number of leading zero bits in x; the result is 64 for x == 0. +@[inline] +pub fn leading_zeros_64(x u64) int { + if x == 0 { + return 64 + } + $if msvc { + return C.__lzcnt64(x) + } $else $if !tinyc { + return C.__builtin_clzll(x) + } + return leading_zeros_64_default(x) +} + +fn C.__builtin_ctz(x u32) int +fn C.__builtin_ctzll(x u64) int +fn C._BitScanForward(pos &int, x u32) u8 +fn C._BitScanForward64(pos &int, x u64) u8 + +// --- TrailingZeros --- +// trailing_zeros_8 returns the number of trailing zero bits in x; the result is 8 for x == 0. +@[inline] +pub fn trailing_zeros_8(x u8) int { + if x == 0 { + return 8 + } + $if msvc { + mut pos := 0 + _ := C._BitScanForward(&pos, x) + return pos + } $else $if !tinyc { + return C.__builtin_ctz(x) + } + return trailing_zeros_8_default(x) +} + +// trailing_zeros_16 returns the number of trailing zero bits in x; the result is 16 for x == 0. +@[inline] +pub fn trailing_zeros_16(x u16) int { + if x == 0 { + return 16 + } + $if msvc { + mut pos := 0 + _ := C._BitScanForward(&pos, x) + return pos + } $else $if !tinyc { + return C.__builtin_ctz(x) + } + return trailing_zeros_16_default(x) +} + +// trailing_zeros_32 returns the number of trailing zero bits in x; the result is 32 for x == 0. +@[inline] +pub fn trailing_zeros_32(x u32) int { + if x == 0 { + return 32 + } + $if msvc { + mut pos := 0 + _ := C._BitScanForward(&pos, x) + return pos + } $else $if !tinyc { + return C.__builtin_ctz(x) + } + return trailing_zeros_32_default(x) +} + +// trailing_zeros_64 returns the number of trailing zero bits in x; the result is 64 for x == 0. +@[inline] +pub fn trailing_zeros_64(x u64) int { + if x == 0 { + return 64 + } + $if msvc { + mut pos := 0 + _ := C._BitScanForward64(&pos, x) + return pos + } $else $if !tinyc { + return C.__builtin_ctzll(x) + } + return trailing_zeros_64_default(x) +} + +fn C.__builtin_popcount(x u32) int +fn C.__builtin_popcountll(x u64) int +fn C.__popcnt(x u32) int +fn C.__popcnt64(x u64) int + +// --- OnesCount --- +// ones_count_8 returns the number of one bits ("population count") in x. +@[inline] +pub fn ones_count_8(x u8) int { + $if msvc { + return C.__popcnt(x) + } $else $if !tinyc { + return C.__builtin_popcount(x) + } + return ones_count_8_default(x) +} + +// ones_count_16 returns the number of one bits ("population count") in x. +@[inline] +pub fn ones_count_16(x u16) int { + $if msvc { + return C.__popcnt(x) + } $else $if !tinyc { + return C.__builtin_popcount(x) + } + return ones_count_16_default(x) +} + +// ones_count_32 returns the number of one bits ("population count") in x. +@[inline] +pub fn ones_count_32(x u32) int { + $if msvc { + return C.__popcnt(x) + } $else $if !tinyc { + return C.__builtin_popcount(x) + } + return ones_count_32_default(x) +} + +// ones_count_64 returns the number of one bits ("population count") in x. +@[inline] +pub fn ones_count_64(x u64) int { + $if msvc { + return C.__popcnt64(x) + } $else $if !tinyc { + return C.__builtin_popcountll(x) + } + return ones_count_64_default(x) +} diff --git a/vlib/math/bits/bits.v b/vlib/math/bits/bits.v index 72cd732bb..e968d3c74 100644 --- a/vlib/math/bits/bits.v +++ b/vlib/math/bits/bits.v @@ -24,35 +24,69 @@ const m4 = u64(0x0000ffff0000ffff) // --- LeadingZeros --- // leading_zeros_8 returns the number of leading zero bits in x; the result is 8 for x == 0. +@[inline] pub fn leading_zeros_8(x u8) int { + return leading_zeros_8_default(x) +} + +@[inline] +fn leading_zeros_8_default(x u8) int { return 8 - len_8(x) } // leading_zeros_16 returns the number of leading zero bits in x; the result is 16 for x == 0. +@[inline] pub fn leading_zeros_16(x u16) int { + return leading_zeros_16_default(x) +} + +@[inline] +fn leading_zeros_16_default(x u16) int { return 16 - len_16(x) } // leading_zeros_32 returns the number of leading zero bits in x; the result is 32 for x == 0. +@[inline] pub fn leading_zeros_32(x u32) int { + return leading_zeros_32_default(x) +} + +@[inline] +fn leading_zeros_32_default(x u32) int { return 32 - len_32(x) } // leading_zeros_64 returns the number of leading zero bits in x; the result is 64 for x == 0. +@[inline] pub fn leading_zeros_64(x u64) int { + return leading_zeros_64_default(x) +} + +@[inline] +fn leading_zeros_64_default(x u64) int { return 64 - len_64(x) } // --- TrailingZeros --- // trailing_zeros_8 returns the number of trailing zero bits in x; the result is 8 for x == 0. -@[direct_array_access] +@[inline] pub fn trailing_zeros_8(x u8) int { + return trailing_zeros_8_default(x) +} + +@[direct_array_access; inline] +fn trailing_zeros_8_default(x u8) int { return int(ntz_8_tab[x]) } // trailing_zeros_16 returns the number of trailing zero bits in x; the result is 16 for x == 0. -@[direct_array_access] +@[inline] pub fn trailing_zeros_16(x u16) int { + return trailing_zeros_16_default(x) +} + +@[direct_array_access; inline] +fn trailing_zeros_16_default(x u16) int { if x == 0 { return 16 } @@ -61,8 +95,13 @@ pub fn trailing_zeros_16(x u16) int { } // trailing_zeros_32 returns the number of trailing zero bits in x; the result is 32 for x == 0. -@[direct_array_access] +@[inline] pub fn trailing_zeros_32(x u32) int { + return trailing_zeros_32_default(x) +} + +@[direct_array_access; inline] +fn trailing_zeros_32_default(x u32) int { if x == 0 { return 32 } @@ -71,8 +110,13 @@ pub fn trailing_zeros_32(x u32) int { } // trailing_zeros_64 returns the number of trailing zero bits in x; the result is 64 for x == 0. -@[direct_array_access] +@[inline] pub fn trailing_zeros_64(x u64) int { + return trailing_zeros_64_default(x) +} + +@[direct_array_access; inline] +fn trailing_zeros_64_default(x u64) int { if x == 0 { return 64 } @@ -92,26 +136,46 @@ pub fn trailing_zeros_64(x u64) int { // --- OnesCount --- // ones_count_8 returns the number of one bits ("population count") in x. -@[direct_array_access] +@[inline] pub fn ones_count_8(x u8) int { + return ones_count_8_default(x) +} + +@[direct_array_access; inline] +fn ones_count_8_default(x u8) int { return int(pop_8_tab[x]) } // ones_count_16 returns the number of one bits ("population count") in x. -@[direct_array_access] +@[inline] pub fn ones_count_16(x u16) int { + return ones_count_16_default(x) +} + +@[direct_array_access; inline] +fn ones_count_16_default(x u16) int { return int(pop_8_tab[x >> 8] + pop_8_tab[x & u16(0xff)]) } // ones_count_32 returns the number of one bits ("population count") in x. -@[direct_array_access] +@[inline] pub fn ones_count_32(x u32) int { + return ones_count_32_default(x) +} + +@[direct_array_access; inline] +fn ones_count_32_default(x u32) int { return int(pop_8_tab[x >> 24] + pop_8_tab[(x >> 16) & 0xff] + pop_8_tab[(x >> 8) & 0xff] + pop_8_tab[x & u32(0xff)]) } // ones_count_64 returns the number of one bits ("population count") in x. +@[inline] pub fn ones_count_64(x u64) int { + return ones_count_64_default(x) +} + +fn ones_count_64_default(x u64) int { // Implementation: Parallel summing of adjacent bits. // See "Hacker's Delight", Chap. 5: Counting Bits. // The following pattern shows the general approach: @@ -376,7 +440,13 @@ const divide_error = 'Divide Error' // half returned in lo. // // This function's execution time does not depend on the inputs. +@[inline] pub fn mul_32(x u32, y u32) (u32, u32) { + return mul_32_default(x, y) +} + +@[inline] +fn mul_32_default(x u32, y u32) (u32, u32) { tmp := u64(x) * u64(y) hi := u32(tmp >> 32) lo := u32(tmp) @@ -388,7 +458,12 @@ pub fn mul_32(x u32, y u32) (u32, u32) { // half returned in lo. // // This function's execution time does not depend on the inputs. +@[inline] pub fn mul_64(x u64, y u64) (u64, u64) { + return mul_64_default(x, y) +} + +fn mul_64_default(x u64, y u64) (u64, u64) { x0 := x & mask32 x1 := x >> 32 y0 := y & mask32 @@ -403,12 +478,49 @@ pub fn mul_64(x u64, y u64) (u64, u64) { return hi, lo } +// mul_add_32 returns the 64-bit result of x * y + z: (hi, lo) = x * y + z +// with the result bits' upper half returned in hi and the lower +// half returned in lo. +@[inline] +pub fn mul_add_32(x u32, y u32, z u32) (u32, u32) { + return mul_add_32_default(x, y, z) +} + +@[inline] +fn mul_add_32_default(x u32, y u32, z u32) (u32, u32) { + tmp := u64(x) * u64(y) + u64(z) + hi := u32(tmp >> 32) + lo := u32(tmp) + return hi, lo +} + +// mul_add_64 returns the 128-bit result of x * y + z: (hi, lo) = x * y + z +// with the result bits' upper half returned in hi and the lower +// half returned in lo. +@[inline] +pub fn mul_add_64(x u64, y u64, z u64) (u64, u64) { + return mul_add_64_default(x, y, z) +} + +@[inline] +fn mul_add_64_default(x u64, y u64, z u64) (u64, u64) { + h, l := mul_64(x, y) + lo := l + z + hi := h + u64(lo < l) + return hi, lo +} + // --- Full-width divide --- // div_32 returns the quotient and remainder of (hi, lo) divided by y: // quo = (hi, lo)/y, rem = (hi, lo)%y with the dividend bits' upper // half in parameter hi and the lower half in parameter lo. // div_32 panics for y == 0 (division by zero) or y <= hi (quotient overflow). +@[inline] pub fn div_32(hi u32, lo u32, y u32) (u32, u32) { + return div_32_default(hi, lo, y) +} + +fn div_32_default(hi u32, lo u32, y u32) (u32, u32) { if y != 0 && y <= hi { panic(overflow_error) } @@ -422,7 +534,12 @@ pub fn div_32(hi u32, lo u32, y u32) (u32, u32) { // quo = (hi, lo)/y, rem = (hi, lo)%y with the dividend bits' upper // half in parameter hi and the lower half in parameter lo. // div_64 panics for y == 0 (division by zero) or y <= hi (quotient overflow). +@[inline] pub fn div_64(hi u64, lo u64, y1 u64) (u64, u64) { + return div_64_default(hi, lo, y1) +} + +fn div_64_default(hi u64, lo u64, y1 u64) (u64, u64) { mut y := y1 if y == 0 { panic(overflow_error) diff --git a/vlib/math/bits/bits_test.v b/vlib/math/bits/bits_test.v index c588bbaf9..78eed8a37 100644 --- a/vlib/math/bits/bits_test.v +++ b/vlib/math/bits/bits_test.v @@ -17,6 +17,7 @@ fn test_bits() { // C.printf("x:%02x lz: %d cmp: %d\n", i << x, leading_zeros_8(i << x), 7-x) assert leading_zeros_8(u8(u8(i) << x)) == 7 - x } + assert leading_zeros_8(0) == 8 // 16 bit i = 1 @@ -24,6 +25,7 @@ fn test_bits() { // C.printf("x:%04x lz: %d cmp: %d\n", u16(i) << x, leading_zeros_16(u16(i) << x), 15-x) assert leading_zeros_16(u16(i) << x) == 15 - x } + assert leading_zeros_16(0) == 16 // 32 bit i = 1 @@ -31,6 +33,7 @@ fn test_bits() { // C.printf("x:%08x lz: %d cmp: %d\n", u32(i) << x, leading_zeros_32(u32(i) << x), 31-x) assert leading_zeros_32(u32(i) << x) == 31 - x } + assert leading_zeros_32(0) == 32 // 64 bit i = 1 @@ -38,6 +41,39 @@ fn test_bits() { // C.printf("x:%016llx lz: %llu cmp: %d\n", u64(i) << x, leading_zeros_64(u64(i) << x), 63-x) assert leading_zeros_64(u64(i) << x) == 63 - x } + assert leading_zeros_64(0) == 64 + + // + // --- TrailingZeros --- + // + + // 8 bit + i = 1 + for x in 0 .. 8 { + assert trailing_zeros_8(u8(u8(i) << x)) == x + } + assert trailing_zeros_8(0) == 8 + + // 16 bit + i = 1 + for x in 0 .. 16 { + assert trailing_zeros_16(u16(i) << x) == x + } + assert trailing_zeros_16(0) == 16 + + // 32 bit + i = 1 + for x in 0 .. 32 { + assert trailing_zeros_32(u32(i) << x) == x + } + assert trailing_zeros_32(0) == 32 + + // 64 bit + i = 1 + for x in 0 .. 64 { + assert trailing_zeros_64(u64(i) << x) == x + } + assert trailing_zeros_64(0) == 64 // // --- ones_count --- @@ -50,6 +86,8 @@ fn test_bits() { assert ones_count_8(u8(i)) == x i = int(u32(i) << 1) + 1 } + assert ones_count_8(0) == 0 + assert ones_count_8(0xFF) == 8 // 16 bit i = 0 @@ -58,6 +96,8 @@ fn test_bits() { assert ones_count_16(u16(i)) == x i = int(u32(i) << 1) + 1 } + assert ones_count_16(0) == 0 + assert ones_count_16(0xFFFF) == 16 // 32 bit i = 0 @@ -66,6 +106,8 @@ fn test_bits() { assert ones_count_32(u32(i)) == x i = int(u32(i) << 1) + 1 } + assert ones_count_32(0) == 0 + assert ones_count_32(0xFFFF_FFFF) == 32 // 64 bit i1 = 0 @@ -74,6 +116,8 @@ fn test_bits() { assert ones_count_64(i1) == x i1 = (i1 << 1) + 1 } + assert ones_count_64(0) == 0 + assert ones_count_64(0xFFFF_FFFF_FFFF_FFFF) == 64 // // --- rotate_left/right --- @@ -241,6 +285,9 @@ fn test_bits() { v1 := v0 - 1 hi, lo := mul_32(v0, v1) assert (u64(hi) << 32) | (u64(lo)) == u64(v0) * u64(v1) + v2 := u32(x) + h, l := mul_add_32(v0, v1, v2) + assert (u64(h) << 32) | (u64(l)) == u64(v0) * u64(v1) + u64(v2) } // 64 bit @@ -252,6 +299,11 @@ fn test_bits() { // C.printf("v0: %llu v1: %llu [%llu,%llu] tt: %llu\n", v0, v1, hi, lo, (v0 >> 32) * (v1 >> 32)) assert (hi & 0xFFFF_FFFF_0000_0000) == (((v0 >> 32) * (v1 >> 32)) & 0xFFFF_FFFF_0000_0000) assert (lo & 0x0000_0000_FFFF_FFFF) == (((v0 & 0x0000_0000_FFFF_FFFF) * (v1 & 0x0000_0000_FFFF_FFFF)) & 0x0000_0000_FFFF_FFFF) + v2 := u64(x) + h, l := mul_add_64(v0, v1, v2) + assert (h & 0xFFFF_FFFF_0000_0000) == (((v0 >> 32) * (v1 >> 32)) & 0xFFFF_FFFF_0000_0000) + assert (l & 0x0000_0000_FFFF_FFFF) == (( + (v0 & 0x0000_0000_FFFF_FFFF) * (v1 & 0x0000_0000_FFFF_FFFF) + v2) & 0x0000_0000_FFFF_FFFF) } // -- 2.39.5