| 1 | // Copyright (c) 2023 Kim Shrier. All rights reserved. |
| 2 | // Use of this source code is governed by an MIT license |
| 3 | // that can be found in the LICENSE file. |
| 4 | // |
| 5 | // Package scrypt implements the key derivation functions as |
| 6 | // described in https://datatracker.ietf.org/doc/html/rfc7914 |
| 7 | module scrypt |
| 8 | |
| 9 | import crypto.pbkdf2 |
| 10 | import crypto.sha256 |
| 11 | import encoding.binary |
| 12 | import math.bits |
| 13 | |
// max_buffer_length is the largest derived key length, in bytes, that
// scrypt accepts: (2^32 - 1) * 32.
pub const max_buffer_length = u64(0xffff_ffff) * 32

// max_blocksize_parallal_product is the bound, 2^30, that scrypt
// compares the product r * p against.
pub const max_blocksize_parallal_product = u64(1) << 30
| 16 | |
// salsa20_8 runs the Salsa20/8 core over a 64 byte block.
//
// The block is read as sixteen little-endian u32 words, mixed by four
// column/row double-rounds (eight rounds in total), and then
// overwritten in place with the result.
fn salsa20_8(mut block []u8) {
	// Fixed-size arrays hold the untouched input words and the working state.
	mut input := [16]u32{}
	mut x := [16]u32{}

	for w in 0 .. 16 {
		word := binary.little_endian_u32_at(block, w * 4)
		input[w] = word
		x[w] = word
	}

	// Each pass is one column round followed by one row round.
	for _ in 0 .. 4 {
		// column round
		x[4] ^= bits.rotate_left_32(x[0] + x[12], 7)
		x[8] ^= bits.rotate_left_32(x[4] + x[0], 9)
		x[12] ^= bits.rotate_left_32(x[8] + x[4], 13)
		x[0] ^= bits.rotate_left_32(x[12] + x[8], 18)

		x[9] ^= bits.rotate_left_32(x[5] + x[1], 7)
		x[13] ^= bits.rotate_left_32(x[9] + x[5], 9)
		x[1] ^= bits.rotate_left_32(x[13] + x[9], 13)
		x[5] ^= bits.rotate_left_32(x[1] + x[13], 18)

		x[14] ^= bits.rotate_left_32(x[10] + x[6], 7)
		x[2] ^= bits.rotate_left_32(x[14] + x[10], 9)
		x[6] ^= bits.rotate_left_32(x[2] + x[14], 13)
		x[10] ^= bits.rotate_left_32(x[6] + x[2], 18)

		x[3] ^= bits.rotate_left_32(x[15] + x[11], 7)
		x[7] ^= bits.rotate_left_32(x[3] + x[15], 9)
		x[11] ^= bits.rotate_left_32(x[7] + x[3], 13)
		x[15] ^= bits.rotate_left_32(x[11] + x[7], 18)

		// row round
		x[1] ^= bits.rotate_left_32(x[0] + x[3], 7)
		x[2] ^= bits.rotate_left_32(x[1] + x[0], 9)
		x[3] ^= bits.rotate_left_32(x[2] + x[1], 13)
		x[0] ^= bits.rotate_left_32(x[3] + x[2], 18)

		x[6] ^= bits.rotate_left_32(x[5] + x[4], 7)
		x[7] ^= bits.rotate_left_32(x[6] + x[5], 9)
		x[4] ^= bits.rotate_left_32(x[7] + x[6], 13)
		x[5] ^= bits.rotate_left_32(x[4] + x[7], 18)

		x[11] ^= bits.rotate_left_32(x[10] + x[9], 7)
		x[8] ^= bits.rotate_left_32(x[11] + x[10], 9)
		x[9] ^= bits.rotate_left_32(x[8] + x[11], 13)
		x[10] ^= bits.rotate_left_32(x[9] + x[8], 18)

		x[12] ^= bits.rotate_left_32(x[15] + x[14], 7)
		x[13] ^= bits.rotate_left_32(x[12] + x[15], 9)
		x[14] ^= bits.rotate_left_32(x[13] + x[12], 13)
		x[15] ^= bits.rotate_left_32(x[14] + x[13], 18)
	}

	// Feed-forward: add the original input words back in and write the
	// sums out as little-endian bytes.
	for w in 0 .. 16 {
		binary.little_endian_put_u32_at(mut block, x[w] + input[w], w * 4)
	}
}
| 78 | |
// blkcpy copies the first `len` bytes of src into dest. Bytes of dest
// at index `len` and beyond are left untouched.
@[inline]
fn blkcpy(mut dest []u8, src []u8, len u32) {
	mut idx := u32(0)
	for idx < len {
		dest[idx] = src[idx]
		idx++
	}
}
| 85 | |
// blkxor XORs the first `len` bytes of src into dest. Bytes of dest
// at index `len` and beyond are left untouched.
@[inline]
fn blkxor(mut dest []u8, src []u8, len u32) {
	mut idx := u32(0)
	for idx < len {
		dest[idx] ^= src[idx]
		idx++
	}
}
| 92 | |
// block_mix applies the scrypt BlockMix step, built on salsa20_8, to
// block in place.
//
// block is treated as 2 * r chunks of 64 bytes each (128 * r bytes in
// total), and temp must provide at least the same amount of scratch
// space. r must be greater than 0.
fn block_mix(mut block []u8, mut temp []u8, r u32) {
	chunk_count := 2 * r

	mut state_arr := [64]u8{}
	// Slice the fixed array inside `unsafe` so the running state is used
	// through a slice directly rather than through a cloned copy.
	mut state := unsafe { state_arr[..] }

	// The running state starts out as the last 64 byte chunk of block.
	blkcpy(mut state, block[((chunk_count - 1) * 64)..], 64)

	// Chain every chunk through salsa20_8 and keep each result in temp.
	for idx in 0 .. chunk_count {
		lo := idx * 64
		hi := lo + 64

		blkxor(mut state, block[lo..hi], 64)
		salsa20_8(mut state)

		blkcpy(mut temp[lo..hi], state, 64)
	}

	// Even-numbered results form the first half of block ...
	for idx in 0 .. r {
		dst := idx * 64
		src := (idx * 2) * 64

		blkcpy(mut block[dst..dst + 64], temp[src..src + 64], 64)
	}

	// ... and odd-numbered results form the second half.
	for idx in 0 .. r {
		dst := (idx + r) * 64
		src := ((idx * 2) + 1) * 64

		blkcpy(mut block[dst..dst + 64], temp[src..src + 64], 64)
	}
}
| 135 | |
// smix runs the scrypt ROMix step over the first 128 * r bytes of
// block and writes the result back into those bytes.
//
// v_block is indexed as n consecutive chunks of 128 * r bytes.
// temp_block is used as two halves of 128 * r bytes each: the first
// half holds the working state and the second half is handed to
// block_mix as its scratch area. Lookup indices are reduced with
// `& (n - 1)`.
fn smix(mut block []u8, r u32, n u64, mut v_block []u8, mut temp_block []u8) {
	chunk := 128 * r
	// Byte offset of the last 64 byte sub-block of the working state.
	tail := int(((2 * r) - 1) * 64)

	blkcpy(mut temp_block, block, chunk)

	// Fill v_block with the successive states, mixing after each store.
	for idx in 0 .. n {
		lo := idx * chunk
		hi := lo + chunk

		blkcpy(mut v_block[int(lo)..int(hi)], temp_block, chunk)
		block_mix(mut temp_block, mut temp_block[chunk..], r)
	}

	// Fold in n stored chunks, each chosen from the current state.
	for _ in 0 .. n {
		pick := binary.little_endian_u64_at(temp_block, tail) & (n - 1)

		lo := pick * chunk
		hi := lo + chunk

		blkxor(mut temp_block, v_block[int(lo)..int(hi)], chunk)
		block_mix(mut temp_block, mut temp_block[chunk..], r)
	}

	blkcpy(mut block, temp_block, chunk)
}
| 161 | |
// OutputBufferLengthError reports a requested output length that is
// larger than max_buffer_length.
struct OutputBufferLengthError {
	Error
	length u64 // the rejected output length
}

// msg renders the error text, including the rejected length and the limit.
fn (e OutputBufferLengthError) msg() string {
	requested := e.length
	return 'the output buffer length, ${requested}, is greater than ${max_buffer_length}'
}
| 170 | |
// BlocksizeParallelProductError reports a blocksize/parallel pair whose
// product is not below max_blocksize_parallal_product.
struct BlocksizeParallelProductError {
	Error
	blocksize u32 // the r value that was passed in
	parallel  u32 // the p value that was passed in
	product   u64 // blocksize * parallel, computed as u64
}

// msg renders the error text. scrypt rejects products that are greater
// than or equal to the limit, so the text says so; a product exactly
// equal to the limit is not "greater than" it.
fn (err BlocksizeParallelProductError) msg() string {
	return 'the product of blocksize ${err.blocksize} * parallel ${err.parallel} = ${err.product}, is greater than or equal to ${max_blocksize_parallal_product}'
}
| 181 | |
// CpuMemoryCostError reports a CPU/memory cost value that is zero or
// not a power of 2.
struct CpuMemoryCostError {
	Error
	cost u64 // the rejected cost value
}

// msg renders the error text, including the rejected cost.
fn (e CpuMemoryCostError) msg() string {
	rejected := e.cost
	return 'the CPU/memory cost ${rejected} must be greater than 0 and also a power of 2'
}
| 190 | |
// scrypt performs password based key derivation using the scrypt algorithm.
//
// The input parameters are:
//
// password - a slice of bytes which is the password being used to
// derive the key. Don't leak this value to anybody.
// salt - a slice of bytes used to make it harder to crack the key.
// n - CPU/Memory cost parameter, must be larger than 0, a power of 2,
// and less than 2^(128 * r / 8).
// r - block size parameter.
// p - parallelization parameter, a positive integer less than or
// equal to ((2^32-1) * hLen) / MFLen where hLen is 32 and
// MFLen is 128 * r.
// dk_len - intended output length in octets of the derived key;
// a positive integer less than or equal to (2^32 - 1) * hLen
// where hLen is 32.
//
// Reasonable values for n, r, and p are n = 1024, r = 8, p = 16.
//
// Returns the derived key of dk_len bytes, or:
// - OutputBufferLengthError when dk_len exceeds max_buffer_length,
// - BlocksizeParallelProductError when r * p is not below
//   max_blocksize_parallal_product,
// - CpuMemoryCostError when n is 0 or not a power of 2,
// - any error reported by pbkdf2.key.
pub fn scrypt(password []u8, salt []u8, n u64, r u32, p u32, dk_len u64) ![]u8 {
	if dk_len > max_buffer_length {
		return OutputBufferLengthError{
			length: dk_len
		}
	}

	if u64(r) * u64(p) >= max_blocksize_parallal_product {
		return BlocksizeParallelProductError{
			blocksize: r
			parallel:  p
			product:   u64(r) * u64(p)
		}
	}

	// the following is a sneaky way to determine if a number is a
	// power of 2. Also, a value of 0 is not allowed.
	if (n & (n - 1)) != 0 || n == 0 {
		return CpuMemoryCostError{
			cost: n
		}
	}

	// Expand the password and salt into p blocks of 128 * r bytes.
	mut b := pbkdf2.key(password, salt, 1, int(128 * r * p), sha256.new())!

	// xy is the per-block working area for smix and v is its lookup table.
	mut xy := []u8{len: int(256 * r), cap: int(256 * r), init: 0}
	mut v := []u8{len: int(128 * r * n), cap: int(128 * r * n), init: 0}

	// Mix each 128 * r byte block in place.
	for i in u32(0) .. p {
		smix(mut b[i * 128 * r..], r, n, mut v, mut xy)
	}

	// Derive exactly dk_len bytes using the mixed blocks as the salt.
	// Asking for 128 * r * p bytes and slicing afterwards fails with an
	// out-of-range slice whenever dk_len is larger than 128 * r * p.
	result := pbkdf2.key(password, b, 1, int(dk_len), sha256.new())!

	return result
}
| 245 | |