From 500477a8d0ccfdd7855d688f828677d510e2ad75 Mon Sep 17 00:00:00 2001 From: "Matthias J. Kannwischer" Date: Fri, 5 Jun 2026 11:35:48 +0800 Subject: [PATCH] x86_64: Replace rej_uniform_eta2/eta4 intrinsics with hand-written assembly Add hand-written x86_64 AVX2 assembly for rej_uniform_eta2 and rej_uniform_eta4, following the rej_uniform approach in #1014: the table is passed as a parameter and all constants are built from immediates (no .rodata), enabling future HOL-Light verification. Wire eta4 to the new asm in meta.h, add the asm entry points and contracts in arith_native_x86_64.h, register the bytecode dump targets in autogen and the Makefile, and add a poly_uniform_eta_4x component benchmark. Signed-off-by: jake massimo --- dev/x86_64/src/arith_native_x86_64.h | 32 ++ dev/x86_64/src/rej_uniform_eta2_avx2_asm.S | 282 ++++++++++++++++++ dev/x86_64/src/rej_uniform_eta4_avx2_asm.S | 226 ++++++++++++++ mldsa/src/native/x86_64/meta.h | 3 +- .../native/x86_64/src/arith_native_x86_64.h | 32 ++ .../x86_64/src/rej_uniform_eta2_avx2_asm.S | 282 ++++++++++++++++++ .../x86_64/src/rej_uniform_eta4_avx2_asm.S | 238 +++++++++++++++ proofs/hol_light/x86_64/Makefile | 8 +- .../x86_64/mldsa/rej_uniform_eta2_avx2_asm.S | 176 +++++++++++ .../x86_64/mldsa/rej_uniform_eta4_avx2_asm.S | 226 ++++++++++++++ scripts/autogen | 12 + test/bench/bench_components_mldsa.c | 10 + test/wycheproof/wycheproof_client.py | 4 +- 13 files changed, 1525 insertions(+), 6 deletions(-) create mode 100644 dev/x86_64/src/rej_uniform_eta2_avx2_asm.S create mode 100644 dev/x86_64/src/rej_uniform_eta4_avx2_asm.S create mode 100644 mldsa/src/native/x86_64/src/rej_uniform_eta2_avx2_asm.S create mode 100644 mldsa/src/native/x86_64/src/rej_uniform_eta4_avx2_asm.S create mode 100644 proofs/hol_light/x86_64/mldsa/rej_uniform_eta2_avx2_asm.S create mode 100644 proofs/hol_light/x86_64/mldsa/rej_uniform_eta4_avx2_asm.S diff --git a/dev/x86_64/src/arith_native_x86_64.h b/dev/x86_64/src/arith_native_x86_64.h index 
6833ada9c..bbf68e396 100644
--- a/dev/x86_64/src/arith_native_x86_64.h
+++ b/dev/x86_64/src/arith_native_x86_64.h
@@ -85,10 +85,42 @@ MLD_MUST_CHECK_RETURN_VALUE
 unsigned mld_rej_uniform_eta2_avx2(
     int32_t *r, const uint8_t buf[MLD_AVX2_REJ_UNIFORM_ETA2_BUFLEN]);
 
+#define mld_rej_uniform_eta2_avx2_asm MLD_NAMESPACE(rej_uniform_eta2_avx2_asm)
+/* This contract must be kept in sync with the HOL-Light specification
+ * in proofs/hol_light/x86_64/proofs/rej_uniform_eta2_avx2_asm.ml */
+MLD_MUST_CHECK_RETURN_VALUE
+unsigned mld_rej_uniform_eta2_avx2_asm(
+    int32_t *r, const uint8_t buf[MLD_AVX2_REJ_UNIFORM_ETA2_BUFLEN],
+    const uint8_t *table)
+__contract__(
+  requires(memory_no_alias(r, sizeof(int32_t) * MLDSA_N))
+  requires(memory_no_alias(buf, MLD_AVX2_REJ_UNIFORM_ETA2_BUFLEN))
+  requires(table == mld_rej_uniform_table)
+  assigns(memory_slice(r, sizeof(int32_t) * MLDSA_N))
+  ensures(return_value <= MLDSA_N)
+  ensures(array_bound(r, 0, return_value, -2, 3)) /* r[i] in [-2, 2] */
+);
+
 #define mld_rej_uniform_eta4_avx2 MLD_NAMESPACE(mld_rej_uniform_eta4_avx2)
 MLD_MUST_CHECK_RETURN_VALUE
 unsigned mld_rej_uniform_eta4_avx2(
     int32_t *r, const uint8_t buf[MLD_AVX2_REJ_UNIFORM_ETA4_BUFLEN]);
+
+#define mld_rej_uniform_eta4_avx2_asm MLD_NAMESPACE(rej_uniform_eta4_avx2_asm)
+/* This contract must be kept in sync with the HOL-Light specification
+ * in proofs/hol_light/x86_64/proofs/rej_uniform_eta4_avx2_asm.ml */
+MLD_MUST_CHECK_RETURN_VALUE
+unsigned mld_rej_uniform_eta4_avx2_asm(
+    int32_t *r, const uint8_t buf[MLD_AVX2_REJ_UNIFORM_ETA4_BUFLEN],
+    const uint8_t *table)
+__contract__(
+  requires(memory_no_alias(r, sizeof(int32_t) * MLDSA_N))
+  requires(memory_no_alias(buf, MLD_AVX2_REJ_UNIFORM_ETA4_BUFLEN))
+  requires(table == mld_rej_uniform_table)
+  assigns(memory_slice(r, sizeof(int32_t) * MLDSA_N))
+  ensures(return_value <= MLDSA_N)
+  ensures(array_bound(r, 0, return_value, -4, 5)) /* r[i] in [-4, 4] */
+);
 #endif /* !MLD_CONFIG_NO_KEYPAIR_API */
 
 #if !defined(MLD_CONFIG_NO_SIGN_API)
diff --git
a/dev/x86_64/src/rej_uniform_eta2_avx2_asm.S b/dev/x86_64/src/rej_uniform_eta2_avx2_asm.S
new file mode 100644
index 000000000..0b8976592
--- /dev/null
+++ b/dev/x86_64/src/rej_uniform_eta2_avx2_asm.S
@@ -0,0 +1,282 @@
+/*
+ * Copyright (c) The mldsa-native project authors
+ * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT
+ */
+
+/* References
+ * ==========
+ *
+ * - [REF_AVX2]
+ *   CRYSTALS-Dilithium optimized AVX2 implementation
+ *   Bai, Ducas, Kiltz, Lepoint, Lyubashevsky, Schwabe, Seiler, Stehlé
+ *   https://github.com/pq-crystals/dilithium/tree/master/avx2
+ */
+
+/*
+ * This file is derived from the public domain
+ * AVX2 Dilithium implementation @[REF_AVX2].
+ */
+
+#include "../../../common.h"
+#if defined(MLD_ARITH_BACKEND_X86_64_DEFAULT) && !defined(MLD_CONFIG_MULTILEVEL_NO_SHARED)
+/* simpasm: header-end */
+
+#define out %rdi
+#define in %rsi
+#define tab %rdx
+
+#define ctr %eax
+#define pos %ecx
+
+#define good %r8d
+#define cnt %r9d
+#define tmp %r10
+#define val %r11d
+
+#define f0 %ymm0
+#define f1 %ymm1
+#define f2 %ymm2
+#define mask %ymm3
+#define eta %ymm4
+#define bound %ymm5
+#define v_const %ymm6
+#define p_const %ymm7
+#define g0 %xmm8
+#define g1 %xmm9
+
+        .text
+
+/*
+ * unsigned mld_rej_uniform_eta2_avx2_asm(int32_t *r, const uint8_t *buf,
+ *                                        const uint8_t *table)
+ *
+ * Rejection sampling for ETA=2 polynomial coefficients.
+ * Extracts 4-bit nibbles from a byte buffer and accepts those < 15.
+ * Applies modulo-5 reduction: t = t - (205 * t >> 10) * 5
+ * Output: coefficient = 2 - t, producing values in [-2, 2].
+ *
+ * Arguments: out (rdi): pointer to output coefficient array r (256 x int32_t)
+ *            in  (rsi): pointer to input byte buffer buf
+ *            tab (rdx): pointer to rejection sampling lookup table (256x8 bytes)
+ *
+ * Returns:   ctr (eax): number of valid coefficients written to r
+ */
+        .balign 4
+        .global MLD_ASM_NAMESPACE(rej_uniform_eta2_avx2_asm)
+MLD_ASM_FN_SYMBOL(rej_uniform_eta2_avx2_asm)
+
+// Construct broadcast constants
+        movl $0x0F0F0F0F, good
+        vmovd good, %xmm3
+        vpbroadcastd %xmm3, mask        // mask: extract low 4 bits from each byte
+
+        movl $0x02020202, good
+        vmovd good, %xmm4
+        vpbroadcastd %xmm4, eta         // eta: ETA=2 in every byte
+
+        movl $0x0F0F0F0F, good
+        vmovd good, %xmm5
+        vpbroadcastd %xmm5, bound       // bound: rejection threshold (15)
+
+        // Modulo-5 magic constants
+        // v = -6560 == 32*round(-2**10 / 5) for multiply-high-round-scale
+        movl $-6560, good
+        vpinsrw $0, good, %xmm6, %xmm6
+        vpbroadcastw %xmm6, v_const     // v_const: -6560 for mulhrs
+
+        movl $5, good
+        vpinsrw $0, good, %xmm7, %xmm7
+        vpbroadcastw %xmm7, p_const     // p_const: 5 for mullo
+
+// Initialize counters
+        xorl ctr, ctr                   // ctr = 0
+        xorl pos, pos                   // pos = 0
+
+/*
+ * Main SIMD loop: process 16 input bytes into 32 nibbles, producing up to 32
+ * coefficients per iteration (processing 4 groups of 8 nibbles each).
+ * Loops while ctr <= MLDSA_N - 8 and pos <= BUFLEN - 16.
+ */
+rej_uniform_eta2_avx2_asm_loop:
+        cmpl $248, ctr                  // MLDSA_N - 8
+        ja rej_uniform_eta2_avx2_asm_scalar
+        cmpl $120, pos                  // MLD_AVX2_REJ_UNIFORM_ETA2_BUFLEN - 16
+        ja rej_uniform_eta2_avx2_asm_scalar
+
+        // Load 16 bytes and extract 32 nibbles into f0
+        vpmovzxbw (in, %rcx), f0        // load 16 bytes, zero-extend to 16 words
+        vpsllw $4, f0, f1               // shift left by 4 to align high nibbles
+        vpor f1, f0, f0                 // OR: each word now has nibble duplicated
+        vpand mask, f0, f0              // mask to get 4 bits per byte
+
+        // Check which nibbles are < 15 and compute eta - nibble
+        vpsubb bound, f0, f1            // f1 = nibble - 15 (negative if valid)
+        vpsubb f0, eta, f0              // f0 = 2 - nibble, in [-12, 2] if valid
+        vpmovmskb f1, good              // extract sign bits (valid = 1)
+
+        // Each valid lane x = 2 - nibble is reduced to its centered
+        // representative mod 5, x + 5 * round(-x / 5), which equals
+        // 2 - (nibble mod 5) and lies in [-2, 2].
+
+        // Process first group of 8 nibbles (low 128 bits, low 64 bits of that)
+        vextracti128 $0, f0, g0         // extract low 128 bits
+        movzbl %r8b, %r10d              // get low 8 bits of mask
+        vmovq (tab, tmp, 8), g1         // load shuffle indices from table
+        vpshufb g1, g0, g1              // compact valid lanes
+        vpmovsxbd g1, f1                // sign-extend bytes to dwords
+
+        // Centered reduction mod 5: f1 += 5 * round(-f1 / 5)
+        vpmulhrsw v_const, f1, f2       // f2 = round(f1 * -6560 / 2^15), per word
+        vpmullw p_const, f2, f2         // f2 = 5 * f2; never negative for x <= 2,
+        vpaddd f2, f1, f1               // so the dword add needs no sign extension
+
+        // f1 now holds 2 - (nibble mod 5) in every valid lane
+
+        vmovdqu f1, (out, %rax, 4)      // store 8 dwords
+        popcntl %r10d, cnt              // popcount of low 8 bits (in r10d)
+        addl cnt, ctr                   // ctr += popcount(low 8 bits)
+        shrl $8, good
+        addl $4, pos                    // consumed 4 input bytes
+
+        cmpl $248, ctr
+        ja rej_uniform_eta2_avx2_asm_scalar
+
+        // Process second group of 8 nibbles (low 128 bits, high 64 bits)
+        vpsrldq $8, g0, g0              // shift right to get next 8 lanes
+        movzbl %r8b, %r10d
+        vmovq (tab, tmp, 8), g1
+        vpshufb g1, g0, g1
+        vpmovsxbd g1, f1
+        // Centered reduction mod 5 (same sequence as for the first group)
+        vpmulhrsw v_const, f1, f2
+        vpmullw p_const, f2, f2
+        vpaddd f2, f1, f1
+        vmovdqu f1, (out, %rax, 4)
+        popcntl %r10d, cnt
+        addl cnt, ctr
+        shrl $8, good
+        addl $4, pos
+
+        cmpl $248, ctr
+        ja rej_uniform_eta2_avx2_asm_scalar
+
+        // Process third group of 8 nibbles (high 128 bits, low 64 bits)
+        vextracti128 $1, f0, g0         // extract high 128 bits
+        movzbl %r8b, %r10d
+        vmovq (tab, tmp, 8), g1
+        vpshufb g1, g0, g1
+        vpmovsxbd g1, f1
+        // Centered reduction mod 5 (same sequence as for the first group)
+        vpmulhrsw v_const, f1, f2
+        vpmullw p_const, f2, f2
+        vpaddd f2, f1, f1
+        vmovdqu f1, (out, %rax, 4)
+        popcntl %r10d, cnt
+        addl cnt, ctr
+        shrl $8, good
+        addl $4, pos
+
+        cmpl $248, ctr
+        ja rej_uniform_eta2_avx2_asm_scalar
+
+        // Process fourth group of 8 nibbles (high 128 bits, high 64 bits)
+        vpsrldq $8, g0, g0
+        movzbl %r8b, %r10d
+        vmovq (tab, tmp, 8), g1
+        vpshufb g1, g0, g1
+        vpmovsxbd g1, f1
+        // Centered reduction mod 5 (same sequence as for the first group)
+        vpmulhrsw v_const, f1, f2
+        vpmullw p_const, f2, f2
+        vpaddd f2, f1, f1
+        vmovdqu f1, (out, %rax, 4)
+        popcntl %r10d, cnt
+        addl cnt, ctr
+        addl $4, pos
+
+        jmp rej_uniform_eta2_avx2_asm_loop
+
+/*
+ * Scalar fallback loop: process remaining bytes one nibble at a time.
+ * Each byte contains two 4-bit nibbles (low and high).
+ * Apply modulo-5 reduction: t = t - (205 * t >> 10) * 5
+ */
+rej_uniform_eta2_avx2_asm_scalar:
+        cmpl $256, ctr                  // MLDSA_N
+        jae rej_uniform_eta2_avx2_asm_done
+        cmpl $136, pos                  // MLD_AVX2_REJ_UNIFORM_ETA2_BUFLEN
+        jae rej_uniform_eta2_avx2_asm_done
+
+        movzbl (in, %rcx), val          // load 1 byte
+        incl pos
+
+        // Process low nibble
+        movl val, %r10d
+        andl $0x0F, %r10d               // extract low 4 bits
+        cmpl $15, %r10d
+        jae rej_uniform_eta2_avx2_asm_high_nibble
+
+        // Apply modulo-5 reduction: t = t - (205 * t >> 10) * 5
+        movl %r10d, val
+        imull $205, val
+        shrl $10, val
+        imull $5, val
+        subl val, %r10d                 // tmp = tmp - (205*tmp>>10)*5
+
+        movl $2, %r11d
+        subl %r10d, %r11d               // 2 - tmp
+        movl %r11d, (out, %rax, 4)
+        incl ctr
+
+        cmpl $256, ctr
+        jae rej_uniform_eta2_avx2_asm_done
+
+rej_uniform_eta2_avx2_asm_high_nibble:
+        // Reload original byte for high nibble (val was clobbered above)
+        movzbl -1(in, %rcx), val        // reload byte
+        shrl $4, val
+        andl $0x0F, val
+        cmpl $15, val
+        jae rej_uniform_eta2_avx2_asm_scalar
+
+        // Apply modulo-5 reduction
+        movl val, %r10d
+        imull $205, %r10d
+        shrl $10, %r10d
+        imull $5, %r10d
+        subl %r10d, val                 // val = val - (205*val>>10)*5
+
+        movl $2, %r10d
+        subl val, %r10d                 // 2 - val
+        movl %r10d, (out, %rax, 4)
+        incl ctr
+
+        jmp rej_uniform_eta2_avx2_asm_scalar
+
+rej_uniform_eta2_avx2_asm_done:
+        ret
+
+/* To facilitate single-compilation-unit (SCU) builds, undefine all macros.
+ * Don't modify by hand -- this is auto-generated by scripts/autogen.
*/
+#undef out
+#undef in
+#undef tab
+#undef ctr
+#undef pos
+#undef good
+#undef cnt
+#undef tmp
+#undef val
+#undef f0
+#undef f1
+#undef f2
+#undef mask
+#undef eta
+#undef bound
+#undef v_const
+#undef p_const
+#undef g0
+#undef g1
+
+/* simpasm: footer-start */
+#endif /* MLD_ARITH_BACKEND_X86_64_DEFAULT && !MLD_CONFIG_MULTILEVEL_NO_SHARED */
diff --git a/dev/x86_64/src/rej_uniform_eta4_avx2_asm.S b/dev/x86_64/src/rej_uniform_eta4_avx2_asm.S
new file mode 100644
index 000000000..4de978117
--- /dev/null
+++ b/dev/x86_64/src/rej_uniform_eta4_avx2_asm.S
@@ -0,0 +1,226 @@
+/*
+ * Copyright (c) The mldsa-native project authors
+ * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT
+ */
+
+/* References
+ * ==========
+ *
+ * - [REF_AVX2]
+ *   CRYSTALS-Dilithium optimized AVX2 implementation
+ *   Bai, Ducas, Kiltz, Lepoint, Lyubashevsky, Schwabe, Seiler, Stehlé
+ *   https://github.com/pq-crystals/dilithium/tree/master/avx2
+ */
+
+/*
+ * This file is derived from the public domain
+ * AVX2 Dilithium implementation @[REF_AVX2].
+ */
+
+#include "../../../common.h"
+#if defined(MLD_ARITH_BACKEND_X86_64_DEFAULT) && !defined(MLD_CONFIG_MULTILEVEL_NO_SHARED)
+/* simpasm: header-end */
+
+#define out %rdi
+#define in %rsi
+#define tab %rdx
+
+#define ctr %eax
+#define pos %ecx
+
+#define good %r8d
+#define cnt %r9d
+#define tmp %r10
+#define val %r11d
+
+#define f0 %ymm0
+#define f1 %ymm1
+#define mask %ymm2
+#define eta %ymm3
+#define bound %ymm4
+#define g0 %xmm5
+#define g1 %xmm6
+
+        .text
+
+/*
+ * unsigned mld_rej_uniform_eta4_avx2_asm(int32_t *r, const uint8_t *buf,
+ *                                        const uint8_t *table)
+ *
+ * Rejection sampling for ETA=4 polynomial coefficients.
+ * Extracts 4-bit nibbles from a byte buffer and accepts those < 9.
+ * Output: coefficient = 4 - nibble, producing values in [-4, 4].
+ *
+ * Arguments: out (rdi): pointer to output coefficient array r (256 x int32_t)
+ *            in  (rsi): pointer to input byte buffer buf (272 bytes)
+ *            tab (rdx): pointer to rejection sampling lookup table (256x8 bytes)
+ *
+ * Returns:   ctr (eax): number of valid coefficients written to r
+ */
+        .balign 4
+        .global MLD_ASM_NAMESPACE(rej_uniform_eta4_avx2_asm)
+MLD_ASM_FN_SYMBOL(rej_uniform_eta4_avx2_asm)
+
+        endbr64
+
+// Construct broadcast constants
+        movl $0x0F0F0F0F, good
+        vmovd good, %xmm2
+        vpbroadcastd %xmm2, mask        // mask: extract low 4 bits from each byte
+
+        movl $0x04040404, good
+        vmovd good, %xmm3
+        vpbroadcastd %xmm3, eta         // eta: ETA=4 in every byte
+
+        movl $0x09090909, good
+        vmovd good, %xmm4
+        vpbroadcastd %xmm4, bound       // bound: rejection threshold (9)
+
+// Initialize counters
+        xorl ctr, ctr                   // ctr = 0
+        xorl pos, pos                   // pos = 0
+
+/*
+ * Main SIMD loop: process 16 input bytes into 32 nibbles, producing up to 32
+ * coefficients per iteration (processing 4 groups of 8 nibbles each).
+ *
+ * Loop-head guards: ctr <= MLDSA_N - 32 = 224 and pos <= BUFLEN - 16 = 256.
+ * Threshold 224 chosen so that worst-case post-iter ctr <= 224 + 32 = 256,
+ * fitting in the 256-int32 output buffer without overshoot.
+ *
+ * Each iter runs all 4 sub-iters unconditionally — no mid-iter early exits.
+ * This makes the proof structurally simpler (mirrors the AArch64 design)
+ * while saving 3 fused-uops/iter and 3 highly-correlated mispredicted
+ * branches at the loop tail.
+ */
+rej_uniform_eta4_avx2_asm_loop:
+        cmpl $224, ctr                  // MLDSA_N - 32
+        ja rej_uniform_eta4_avx2_asm_scalar
+        cmpl $256, pos                  // MLD_AVX2_REJ_UNIFORM_ETA4_BUFLEN - 16
+        ja rej_uniform_eta4_avx2_asm_scalar
+
+        // Load 16 bytes and extract 32 nibbles into f0
+        vpmovzxbw (in, %rcx), f0        // load 16 bytes, zero-extend to 16 words
+        vpsllw $4, f0, f1               // shift left by 4 to align high nibbles
+        vpor f1, f0, f0                 // OR: each word now has nibble duplicated
+        vpand mask, f0, f0              // mask to get 4 bits per byte
+
+        // Check which nibbles are < 9 and compute eta - nibble
+        vpsubb bound, f0, f1            // f1 = nibble - 9 (negative if valid)
+        vpsubb f0, eta, f0              // f0 = 4 - nibble
+        vpmovmskb f1, good              // extract sign bits (valid = 1)
+
+        // Process first group of 8 nibbles (low 128 bits, low 64 bits of that)
+        vextracti128 $0, f0, g0         // extract low 128 bits
+        movzbl %r8b, %r10d              // get low 8 bits of mask
+        vmovq (tab, tmp, 8), g1         // load shuffle indices from table
+        vpshufb g1, g0, g1              // compact valid nibbles
+        vpmovsxbd g1, f1                // sign-extend bytes to dwords
+        vmovdqu f1, (out, %rax, 4)      // store 8 dwords
+        popcntl %r10d, cnt              // popcount of low 8 bits (in r10d)
+        addl cnt, ctr                   // ctr += popcount(low 8 bits)
+        shrl $8, good
+        addl $4, pos                    // consumed 4 input bytes
+
+        // Process second group of 8 nibbles (low 128 bits, high 64 bits)
+        vpsrldq $8, g0, g0              // shift right to get next 8 nibbles
+        movzbl %r8b, %r10d
+        vmovq (tab, tmp, 8), g1
+        vpshufb g1, g0, g1
+        vpmovsxbd g1, f1
+        vmovdqu f1, (out, %rax, 4)
+        popcntl %r10d, cnt              // popcount of low 8 bits (in r10d)
+        addl cnt, ctr
+        shrl $8, good
+        addl $4, pos
+
+        // Process third group of 8 nibbles (high 128 bits, low 64 bits)
+        vextracti128 $1, f0, g0         // extract high 128 bits
+        movzbl %r8b, %r10d
+        vmovq (tab, tmp, 8), g1
+        vpshufb g1, g0, g1
+        vpmovsxbd g1, f1
+        vmovdqu f1, (out, %rax, 4)
+        popcntl %r10d, cnt              // popcount of low 8 bits (in r10d)
+        addl cnt, ctr
+        shrl $8, good
+        addl $4, pos
+
+        // Process fourth group of 8 nibbles (high 128 bits, high 64 bits)
+        vpsrldq $8, g0, g0
+        movzbl %r8b, %r10d
+        vmovq (tab, tmp, 8), g1
+        vpshufb g1, g0, g1
+        vpmovsxbd g1, f1
+        vmovdqu f1, (out, %rax, 4)
+        popcntl %r10d, cnt              // popcount of low 8 bits (in r10d)
+        addl cnt, ctr
+        addl $4, pos
+
+        jmp rej_uniform_eta4_avx2_asm_loop
+
+/*
+ * Scalar fallback loop: process remaining bytes one nibble at a time.
+ * Each byte contains two 4-bit nibbles (low and high).
+ */
+rej_uniform_eta4_avx2_asm_scalar:
+        cmpl $256, ctr                  // MLDSA_N
+        jae rej_uniform_eta4_avx2_asm_done
+        cmpl $272, pos                  // MLD_AVX2_REJ_UNIFORM_ETA4_BUFLEN
+        jae rej_uniform_eta4_avx2_asm_done
+
+        movzbl (in, %rcx), val          // load 1 byte
+        incl pos
+
+        // Process low nibble
+        movl val, %r10d
+        andl $0x0F, %r10d               // extract low 4 bits
+        cmpl $9, %r10d
+        jae rej_uniform_eta4_avx2_asm_high_nibble
+
+        movl $4, %r11d
+        subl %r10d, %r11d               // 4 - nibble
+        movl %r11d, (out, %rax, 4)
+        incl ctr
+
+        cmpl $256, ctr
+        jae rej_uniform_eta4_avx2_asm_done
+
+rej_uniform_eta4_avx2_asm_high_nibble:
+        // Process high nibble
+        shrl $4, val
+        andl $0x0F, val
+        cmpl $9, val
+        jae rej_uniform_eta4_avx2_asm_scalar
+
+        movl $4, %r10d
+        subl val, %r10d                 // 4 - nibble
+        movl %r10d, (out, %rax, 4)
+        incl ctr
+
+        jmp rej_uniform_eta4_avx2_asm_scalar
+
+rej_uniform_eta4_avx2_asm_done:
+        ret
+
+/* To facilitate single-compilation-unit (SCU) builds, undefine all macros.
+ * Don't modify by hand -- this is auto-generated by scripts/autogen.
*/
+#undef out
+#undef in
+#undef tab
+#undef ctr
+#undef pos
+#undef good
+#undef cnt
+#undef tmp
+#undef val
+#undef f0
+#undef f1
+#undef mask
+#undef eta
+#undef bound
+#undef g0
+#undef g1
+
+/* simpasm: footer-start */
+#endif /* MLD_ARITH_BACKEND_X86_64_DEFAULT && !MLD_CONFIG_MULTILEVEL_NO_SHARED */
diff --git a/mldsa/src/native/x86_64/meta.h b/mldsa/src/native/x86_64/meta.h
index 4d37bbc01..f9eefb2ef 100644
--- a/mldsa/src/native/x86_64/meta.h
+++ b/mldsa/src/native/x86_64/meta.h
@@ -135,7 +135,8 @@ static MLD_INLINE int mld_rej_uniform_eta4_native(int32_t *r, unsigned len,
    * We declassify prior the input data and mark the outputs as secret.
    */
   MLD_CT_TESTING_DECLASSIFY(buf, buflen);
-  outlen = mld_rej_uniform_eta4_avx2(r, buf);
+  outlen = mld_rej_uniform_eta4_avx2_asm(r, buf,
+                                         (const uint8_t *)mld_rej_uniform_table);
   MLD_CT_TESTING_SECRET(r, sizeof(int32_t) * outlen);
   /* Safety: outlen is at most MLDSA_N and, hence, this cast is safe. */
   return (int)outlen;
diff --git a/mldsa/src/native/x86_64/src/arith_native_x86_64.h b/mldsa/src/native/x86_64/src/arith_native_x86_64.h
index 6833ada9c..bbf68e396 100644
--- a/mldsa/src/native/x86_64/src/arith_native_x86_64.h
+++ b/mldsa/src/native/x86_64/src/arith_native_x86_64.h
@@ -85,10 +85,42 @@ MLD_MUST_CHECK_RETURN_VALUE
 unsigned mld_rej_uniform_eta2_avx2(
     int32_t *r, const uint8_t buf[MLD_AVX2_REJ_UNIFORM_ETA2_BUFLEN]);
 
+#define mld_rej_uniform_eta2_avx2_asm MLD_NAMESPACE(rej_uniform_eta2_avx2_asm)
+/* This contract must be kept in sync with the HOL-Light specification
+ * in proofs/hol_light/x86_64/proofs/rej_uniform_eta2_avx2_asm.ml */
+MLD_MUST_CHECK_RETURN_VALUE
+unsigned mld_rej_uniform_eta2_avx2_asm(
+    int32_t *r, const uint8_t buf[MLD_AVX2_REJ_UNIFORM_ETA2_BUFLEN],
+    const uint8_t *table)
+__contract__(
+  requires(memory_no_alias(r, sizeof(int32_t) * MLDSA_N))
+  requires(memory_no_alias(buf, MLD_AVX2_REJ_UNIFORM_ETA2_BUFLEN))
+  requires(table == mld_rej_uniform_table)
+  assigns(memory_slice(r, sizeof(int32_t) * MLDSA_N))
+  ensures(return_value <= MLDSA_N)
+  ensures(array_bound(r, 0, return_value, -2, 3)) /* r[i] in [-2, 2] */
+);
+
 #define mld_rej_uniform_eta4_avx2 MLD_NAMESPACE(mld_rej_uniform_eta4_avx2)
 MLD_MUST_CHECK_RETURN_VALUE
 unsigned mld_rej_uniform_eta4_avx2(
     int32_t *r, const uint8_t buf[MLD_AVX2_REJ_UNIFORM_ETA4_BUFLEN]);
+
+#define mld_rej_uniform_eta4_avx2_asm MLD_NAMESPACE(rej_uniform_eta4_avx2_asm)
+/* This contract must be kept in sync with the HOL-Light specification
+ * in proofs/hol_light/x86_64/proofs/rej_uniform_eta4_avx2_asm.ml */
+MLD_MUST_CHECK_RETURN_VALUE
+unsigned mld_rej_uniform_eta4_avx2_asm(
+    int32_t *r, const uint8_t buf[MLD_AVX2_REJ_UNIFORM_ETA4_BUFLEN],
+    const uint8_t *table)
+__contract__(
+  requires(memory_no_alias(r, sizeof(int32_t) * MLDSA_N))
+  requires(memory_no_alias(buf, MLD_AVX2_REJ_UNIFORM_ETA4_BUFLEN))
+  requires(table == mld_rej_uniform_table)
+  assigns(memory_slice(r, sizeof(int32_t) * MLDSA_N))
+  ensures(return_value <= MLDSA_N)
+  ensures(array_bound(r, 0, return_value, -4, 5)) /* r[i] in [-4, 4] */
+);
 #endif /* !MLD_CONFIG_NO_KEYPAIR_API */
 
 #if !defined(MLD_CONFIG_NO_SIGN_API)
diff --git a/mldsa/src/native/x86_64/src/rej_uniform_eta2_avx2_asm.S b/mldsa/src/native/x86_64/src/rej_uniform_eta2_avx2_asm.S
new file mode 100644
index 000000000..0b8976592
--- /dev/null
+++ b/mldsa/src/native/x86_64/src/rej_uniform_eta2_avx2_asm.S
@@ -0,0 +1,282 @@
+/*
+ * Copyright (c) The mldsa-native project authors
+ * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT
+ */
+
+/* References
+ * ==========
+ *
+ * - [REF_AVX2]
+ *   CRYSTALS-Dilithium optimized AVX2 implementation
+ *   Bai, Ducas, Kiltz, Lepoint, Lyubashevsky, Schwabe, Seiler, Stehlé
+ *   https://github.com/pq-crystals/dilithium/tree/master/avx2
+ */
+
+/*
+ * This file is derived from the public domain
+ * AVX2 Dilithium implementation @[REF_AVX2].
+ */
+
+#include "../../../common.h"
+#if defined(MLD_ARITH_BACKEND_X86_64_DEFAULT) && !defined(MLD_CONFIG_MULTILEVEL_NO_SHARED)
+/* simpasm: header-end */
+
+#define out %rdi
+#define in %rsi
+#define tab %rdx
+
+#define ctr %eax
+#define pos %ecx
+
+#define good %r8d
+#define cnt %r9d
+#define tmp %r10
+#define val %r11d
+
+#define f0 %ymm0
+#define f1 %ymm1
+#define f2 %ymm2
+#define mask %ymm3
+#define eta %ymm4
+#define bound %ymm5
+#define v_const %ymm6
+#define p_const %ymm7
+#define g0 %xmm8
+#define g1 %xmm9
+
+        .text
+
+/*
+ * unsigned mld_rej_uniform_eta2_avx2_asm(int32_t *r, const uint8_t *buf,
+ *                                        const uint8_t *table)
+ *
+ * Rejection sampling for ETA=2 polynomial coefficients.
+ * Extracts 4-bit nibbles from a byte buffer and accepts those < 15.
+ * Applies modulo-5 reduction: t = t - (205 * t >> 10) * 5
+ * Output: coefficient = 2 - t, producing values in [-2, 2].
+ *
+ * Arguments: out (rdi): pointer to output coefficient array r (256 x int32_t)
+ *            in  (rsi): pointer to input byte buffer buf
+ *            tab (rdx): pointer to rejection sampling lookup table (256x8 bytes)
+ *
+ * Returns:   ctr (eax): number of valid coefficients written to r
+ */
+        .balign 4
+        .global MLD_ASM_NAMESPACE(rej_uniform_eta2_avx2_asm)
+MLD_ASM_FN_SYMBOL(rej_uniform_eta2_avx2_asm)
+
+// Construct broadcast constants
+        movl $0x0F0F0F0F, good
+        vmovd good, %xmm3
+        vpbroadcastd %xmm3, mask        // mask: extract low 4 bits from each byte
+
+        movl $0x02020202, good
+        vmovd good, %xmm4
+        vpbroadcastd %xmm4, eta         // eta: ETA=2 in every byte
+
+        movl $0x0F0F0F0F, good
+        vmovd good, %xmm5
+        vpbroadcastd %xmm5, bound       // bound: rejection threshold (15)
+
+        // Modulo-5 magic constants
+        // v = -6560 == 32*round(-2**10 / 5) for multiply-high-round-scale
+        movl $-6560, good
+        vpinsrw $0, good, %xmm6, %xmm6
+        vpbroadcastw %xmm6, v_const     // v_const: -6560 for mulhrs
+
+        movl $5, good
+        vpinsrw $0, good, %xmm7, %xmm7
+        vpbroadcastw %xmm7, p_const     // p_const: 5 for mullo
+
+// Initialize counters
+        xorl ctr, ctr                   // ctr = 0
+        xorl pos, pos                   // pos = 0
+
+/*
+ * Main SIMD loop: process 16 input bytes into 32 nibbles, producing up to 32
+ * coefficients per iteration (processing 4 groups of 8 nibbles each).
+ * Loops while ctr <= MLDSA_N - 8 and pos <= BUFLEN - 16.
+ */
+rej_uniform_eta2_avx2_asm_loop:
+        cmpl $248, ctr                  // MLDSA_N - 8
+        ja rej_uniform_eta2_avx2_asm_scalar
+        cmpl $120, pos                  // MLD_AVX2_REJ_UNIFORM_ETA2_BUFLEN - 16
+        ja rej_uniform_eta2_avx2_asm_scalar
+
+        // Load 16 bytes and extract 32 nibbles into f0
+        vpmovzxbw (in, %rcx), f0        // load 16 bytes, zero-extend to 16 words
+        vpsllw $4, f0, f1               // shift left by 4 to align high nibbles
+        vpor f1, f0, f0                 // OR: each word now has nibble duplicated
+        vpand mask, f0, f0              // mask to get 4 bits per byte
+
+        // Check which nibbles are < 15 and compute eta - nibble
+        vpsubb bound, f0, f1            // f1 = nibble - 15 (negative if valid)
+        vpsubb f0, eta, f0              // f0 = 2 - nibble, in [-12, 2] if valid
+        vpmovmskb f1, good              // extract sign bits (valid = 1)
+
+        // Each valid lane x = 2 - nibble is reduced to its centered
+        // representative mod 5, x + 5 * round(-x / 5), which equals
+        // 2 - (nibble mod 5) and lies in [-2, 2].
+
+        // Process first group of 8 nibbles (low 128 bits, low 64 bits of that)
+        vextracti128 $0, f0, g0         // extract low 128 bits
+        movzbl %r8b, %r10d              // get low 8 bits of mask
+        vmovq (tab, tmp, 8), g1         // load shuffle indices from table
+        vpshufb g1, g0, g1              // compact valid lanes
+        vpmovsxbd g1, f1                // sign-extend bytes to dwords
+
+        // Centered reduction mod 5: f1 += 5 * round(-f1 / 5)
+        vpmulhrsw v_const, f1, f2       // f2 = round(f1 * -6560 / 2^15), per word
+        vpmullw p_const, f2, f2         // f2 = 5 * f2; never negative for x <= 2,
+        vpaddd f2, f1, f1               // so the dword add needs no sign extension
+
+        // f1 now holds 2 - (nibble mod 5) in every valid lane
+
+        vmovdqu f1, (out, %rax, 4)      // store 8 dwords
+        popcntl %r10d, cnt              // popcount of low 8 bits (in r10d)
+        addl cnt, ctr                   // ctr += popcount(low 8 bits)
+        shrl $8, good
+        addl $4, pos                    // consumed 4 input bytes
+
+        cmpl $248, ctr
+        ja rej_uniform_eta2_avx2_asm_scalar
+
+        // Process second group of 8 nibbles (low 128 bits, high 64 bits)
+        vpsrldq $8, g0, g0              // shift right to get next 8 lanes
+        movzbl %r8b, %r10d
+        vmovq (tab, tmp, 8), g1
+        vpshufb g1, g0, g1
+        vpmovsxbd g1, f1
+        // Centered reduction mod 5 (same sequence as for the first group)
+        vpmulhrsw v_const, f1, f2
+        vpmullw p_const, f2, f2
+        vpaddd f2, f1, f1
+        vmovdqu f1, (out, %rax, 4)
+        popcntl %r10d, cnt
+        addl cnt, ctr
+        shrl $8, good
+        addl $4, pos
+
+        cmpl $248, ctr
+        ja rej_uniform_eta2_avx2_asm_scalar
+
+        // Process third group of 8 nibbles (high 128 bits, low 64 bits)
+        vextracti128 $1, f0, g0         // extract high 128 bits
+        movzbl %r8b, %r10d
+        vmovq (tab, tmp, 8), g1
+        vpshufb g1, g0, g1
+        vpmovsxbd g1, f1
+        // Centered reduction mod 5 (same sequence as for the first group)
+        vpmulhrsw v_const, f1, f2
+        vpmullw p_const, f2, f2
+        vpaddd f2, f1, f1
+        vmovdqu f1, (out, %rax, 4)
+        popcntl %r10d, cnt
+        addl cnt, ctr
+        shrl $8, good
+        addl $4, pos
+
+        cmpl $248, ctr
+        ja rej_uniform_eta2_avx2_asm_scalar
+
+        // Process fourth group of 8 nibbles (high 128 bits, high 64 bits)
+        vpsrldq $8, g0, g0
+        movzbl %r8b, %r10d
+        vmovq (tab, tmp, 8), g1
+        vpshufb g1, g0, g1
+        vpmovsxbd g1, f1
+        // Centered reduction mod 5 (same sequence as for the first group)
+        vpmulhrsw v_const, f1, f2
+        vpmullw p_const, f2, f2
+        vpaddd f2, f1, f1
+        vmovdqu f1, (out, %rax, 4)
+        popcntl %r10d, cnt
+        addl cnt, ctr
+        addl $4, pos
+
+        jmp rej_uniform_eta2_avx2_asm_loop
+
+/*
+ * Scalar fallback loop: process remaining bytes one nibble at a time.
+ * Each byte contains two 4-bit nibbles (low and high).
+ * Apply modulo-5 reduction: t = t - (205 * t >> 10) * 5
+ */
+rej_uniform_eta2_avx2_asm_scalar:
+        cmpl $256, ctr                  // MLDSA_N
+        jae rej_uniform_eta2_avx2_asm_done
+        cmpl $136, pos                  // MLD_AVX2_REJ_UNIFORM_ETA2_BUFLEN
+        jae rej_uniform_eta2_avx2_asm_done
+
+        movzbl (in, %rcx), val          // load 1 byte
+        incl pos
+
+        // Process low nibble
+        movl val, %r10d
+        andl $0x0F, %r10d               // extract low 4 bits
+        cmpl $15, %r10d
+        jae rej_uniform_eta2_avx2_asm_high_nibble
+
+        // Apply modulo-5 reduction: t = t - (205 * t >> 10) * 5
+        movl %r10d, val
+        imull $205, val
+        shrl $10, val
+        imull $5, val
+        subl val, %r10d                 // tmp = tmp - (205*tmp>>10)*5
+
+        movl $2, %r11d
+        subl %r10d, %r11d               // 2 - tmp
+        movl %r11d, (out, %rax, 4)
+        incl ctr
+
+        cmpl $256, ctr
+        jae rej_uniform_eta2_avx2_asm_done
+
+rej_uniform_eta2_avx2_asm_high_nibble:
+        // Reload original byte for high nibble (val was clobbered above)
+        movzbl -1(in, %rcx), val        // reload byte
+        shrl $4, val
+        andl $0x0F, val
+        cmpl $15, val
+        jae rej_uniform_eta2_avx2_asm_scalar
+
+        // Apply modulo-5 reduction
+        movl val, %r10d
+        imull $205, %r10d
+        shrl $10, %r10d
+        imull $5, %r10d
+        subl %r10d, val                 // val = val - (205*val>>10)*5
+
+        movl $2, %r10d
+        subl val, %r10d                 // 2 - val
+        movl %r10d, (out, %rax, 4)
+        incl ctr
+
+        jmp rej_uniform_eta2_avx2_asm_scalar
+
+rej_uniform_eta2_avx2_asm_done:
+        ret
+
+/* To facilitate single-compilation-unit (SCU) builds, undefine all macros.
+ * Don't modify by hand -- this is auto-generated by scripts/autogen.
*/
+#undef out
+#undef in
+#undef tab
+#undef ctr
+#undef pos
+#undef good
+#undef cnt
+#undef tmp
+#undef val
+#undef f0
+#undef f1
+#undef f2
+#undef mask
+#undef eta
+#undef bound
+#undef v_const
+#undef p_const
+#undef g0
+#undef g1
+
+/* simpasm: footer-start */
+#endif /* MLD_ARITH_BACKEND_X86_64_DEFAULT && !MLD_CONFIG_MULTILEVEL_NO_SHARED */
diff --git a/mldsa/src/native/x86_64/src/rej_uniform_eta4_avx2_asm.S b/mldsa/src/native/x86_64/src/rej_uniform_eta4_avx2_asm.S
new file mode 100644
index 000000000..89dd7f56f
--- /dev/null
+++ b/mldsa/src/native/x86_64/src/rej_uniform_eta4_avx2_asm.S
@@ -0,0 +1,238 @@
+/*
+ * Copyright (c) The mldsa-native project authors
+ * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT
+ */
+
+/* References
+ * ==========
+ *
+ * - [REF_AVX2]
+ *   CRYSTALS-Dilithium optimized AVX2 implementation
+ *   Bai, Ducas, Kiltz, Lepoint, Lyubashevsky, Schwabe, Seiler, Stehlé
+ *   https://github.com/pq-crystals/dilithium/tree/master/avx2
+ */
+
+/*
+ * This file is derived from the public domain
+ * AVX2 Dilithium implementation @[REF_AVX2].
+ */
+
+#include "../../../common.h"
+#if defined(MLD_ARITH_BACKEND_X86_64_DEFAULT) && \
+    !defined(MLD_CONFIG_NO_KEYPAIR_API) && \
+    !defined(MLD_CONFIG_MULTILEVEL_NO_SHARED) && \
+    (defined(MLD_CONFIG_MULTILEVEL_WITH_SHARED) || MLDSA_ETA == 4)
+/* simpasm: header-end */
+
+#define out %rdi
+#define in %rsi
+#define tab %rdx
+
+#define ctr %eax
+#define pos %ecx
+
+#define good %r8d
+#define cnt %r9d
+#define tmp %r10
+#define val %r11d
+
+#define f0 %ymm0
+#define f1 %ymm1
+#define mask %ymm2
+#define eta %ymm3
+#define bound %ymm4
+#define g0 %xmm5
+#define g1 %xmm6
+
+        .text
+
+/*
+ * unsigned mld_rej_uniform_eta4_avx2_asm(int32_t *r, const uint8_t *buf,
+ *                                        const uint8_t *table)
+ *
+ * Rejection sampling for ETA=4 polynomial coefficients.
+ * Splits each input byte into two 4-bit nibbles and accepts those < 9.
+ * Output: coefficient = 4 - nibble, producing values in [-4, 4].
+ *
+ * Arguments: out (rdi): pointer to output coefficient array r (256 x int32_t)
+ *            in (rsi): pointer to input byte buffer buf (272 bytes)
+ *            tab (rdx): pointer to rejection sampling lookup table (256x8 bytes)
+ *
+ * Returns: ctr (eax): number of valid coefficients written to r
+ */
+        .balign 4
+        .global MLD_ASM_NAMESPACE(rej_uniform_eta4_avx2_asm)
+MLD_ASM_FN_SYMBOL(rej_uniform_eta4_avx2_asm)
+
+        endbr64
+
+// Construct broadcast constants from immediates (no .rodata loads)
+        movl $0x0F0F0F0F, good
+        vmovd good, %xmm2
+        vpbroadcastd %xmm2, mask  // mask: 0x0F in every byte (low 4 bits)
+
+        movl $0x04040404, good
+        vmovd good, %xmm3
+        vpbroadcastd %xmm3, eta  // eta: ETA=4 in every byte
+
+        movl $0x09090909, good
+        vmovd good, %xmm4
+        vpbroadcastd %xmm4, bound  // bound: rejection threshold 9 in every byte
+
+// Initialize counters
+        xorl ctr, ctr  // ctr = 0 (coefficients written so far)
+        xorl pos, pos  // pos = 0 (input bytes consumed so far)
+
+/*
+ * Main SIMD loop: process 16 input bytes into 32 nibbles, producing up to
+ * 32 coefficients per iteration (4 sub-iterations of 8 nibbles each).
+ *
+ * Loop-head guards: ctr <= MLDSA_N - 8 = 248 and pos <= BUFLEN - 16 = 256.
+ *
+ * Mid-iter early exits at ctr > 248 prevent buffer overshoot: each sub-iter
+ * stores 8 ints starting at r[ctr], and ctr advances by popcount(<= 8). With
+ * ctr <= 248 entering a sub-iter, the store touches at most r[248..255] —
+ * exactly fitting the 256-int output buffer.
+ */
+rej_uniform_eta4_avx2_asm_loop:
+        cmpl $248, ctr  // MLDSA_N - 8
+        ja rej_uniform_eta4_avx2_asm_scalar
+        cmpl $256, pos  // MLD_AVX2_REJ_UNIFORM_ETA4_BUFLEN - 16
+        ja rej_uniform_eta4_avx2_asm_scalar
+
+        // Load 16 bytes and extract 32 nibbles into f0
+        vpmovzxbw (in, %rcx), f0  // load 16 bytes at buf[pos], zero-extend to 16 words
+        vpsllw $4, f0, f1  // shift left by 4: high nibble moves into the upper byte
+        vpor f1, f0, f0  // word = b | (b << 4)
+        vpand mask, f0, f0  // per word: low byte = low nibble, high byte = high nibble
+
+        // Check which nibbles are < 9 and compute eta - nibble
+        vpsubb bound, f0, f1  // f1 = nibble - 9 (negative if valid)
+        vpsubb f0, eta, f0  // f0 = 4 - nibble
+        vpmovmskb f1, good  // extract sign bits (valid = 1), one bit per nibble
+
+        // Sub-iter 1: extract low 128 bits of f0; process bits 0..7 of good
+        vextracti128 $0, f0, g0
+        movzbl %r8b, %r10d  // tmp = good & 0xFF
+        vmovq (tab, tmp, 8), g1  // g1 = table[good & 0xFF] (8 byte indices)
+        vpshufb g1, g0, g1  // compact accepted nibbles to front
+        vpmovsxbd g1, f1  // sign-extend 8 bytes -> 8 int32
+        vmovdqu f1, (out, %rax, 4)  // store 8 ints at r[ctr]; only the first cnt are counted
+        popcntl %r10d, cnt  // cnt = popcount(good & 0xFF)
+        addl cnt, ctr  // ctr += cnt
+        shrl $8, good  // shift good for next sub-iter
+        addl $4, pos  // 4 input bytes consumed
+
+        cmpl $248, ctr  // mid-iter exit if ctr > 248
+        ja rej_uniform_eta4_avx2_asm_scalar
+
+        // Sub-iter 2: shift g0 down by 8 bytes; process next 8 bits of good
+        vpsrldq $8, g0, g0
+        movzbl %r8b, %r10d  // tmp = good & 0xFF
+        vmovq (tab, tmp, 8), g1
+        vpshufb g1, g0, g1
+        vpmovsxbd g1, f1
+        vmovdqu f1, (out, %rax, 4)
+        popcntl %r10d, cnt
+        addl cnt, ctr
+        shrl $8, good
+        addl $4, pos
+
+        cmpl $248, ctr  // mid-iter exit if ctr > 248
+        ja rej_uniform_eta4_avx2_asm_scalar
+
+        // Sub-iter 3: extract high 128 bits of f0; process next 8 bits of good
+        vextracti128 $1, f0, g0
+        movzbl %r8b, %r10d  // tmp = good & 0xFF
+        vmovq (tab, tmp, 8), g1
+        vpshufb g1, g0, g1
+        vpmovsxbd g1, f1
+        vmovdqu f1, (out, %rax, 4)
+        popcntl %r10d, cnt
+        addl cnt, ctr
+        shrl $8, good
+        addl $4, pos
+
+        cmpl $248, ctr  // mid-iter exit if ctr > 248
+        ja rej_uniform_eta4_avx2_asm_scalar
+
+        // Sub-iter 4: shift g0 down by 8 bytes; process final 8 bits of good
+        vpsrldq $8, g0, g0
+        movzbl %r8b, %r10d  // tmp = good & 0xFF
+        vmovq (tab, tmp, 8), g1
+        vpshufb g1, g0, g1
+        vpmovsxbd g1, f1
+        vmovdqu f1, (out, %rax, 4)
+        popcntl %r10d, cnt
+        addl cnt, ctr
+        addl $4, pos
+
+        jmp rej_uniform_eta4_avx2_asm_loop
+
+/*
+ * Scalar fallback loop: process remaining bytes one nibble at a time.
+ * Each byte contains two 4-bit nibbles (low and high), each accepted iff < 9.
+ */
+rej_uniform_eta4_avx2_asm_scalar:
+        cmpl $256, ctr  // MLDSA_N
+        jae rej_uniform_eta4_avx2_asm_done
+        cmpl $272, pos  // MLD_AVX2_REJ_UNIFORM_ETA4_BUFLEN
+        jae rej_uniform_eta4_avx2_asm_done
+
+        movzbl (in, %rcx), val  // load buf[pos]
+        incl pos
+
+        // Process low nibble
+        movl val, %r10d
+        andl $0x0F, %r10d  // extract low 4 bits
+        cmpl $9, %r10d
+        jae rej_uniform_eta4_avx2_asm_high_nibble
+
+        movl $4, cnt  // cnt is scratch here; val must survive for the high nibble
+        subl %r10d, cnt  // 4 - nibble
+        movl cnt, (out, %rax, 4)
+        incl ctr
+
+        cmpl $256, ctr  // output full: drop the high nibble
+        jae rej_uniform_eta4_avx2_asm_done
+
+rej_uniform_eta4_avx2_asm_high_nibble:
+        // Process high nibble
+        shrl $4, val
+        andl $0x0F, val
+        cmpl $9, val
+        jae rej_uniform_eta4_avx2_asm_scalar
+
+        movl $4, %r10d
+        subl val, %r10d  // 4 - nibble
+        movl %r10d, (out, %rax, 4)
+        incl ctr
+
+        jmp rej_uniform_eta4_avx2_asm_scalar
+
+rej_uniform_eta4_avx2_asm_done:
+        ret
+
+/* To facilitate single-compilation-unit (SCU) builds, undefine all macros.
+ * Don't modify by hand -- this is auto-generated by scripts/autogen.
*/
+#undef out
+#undef in
+#undef tab
+#undef ctr
+#undef pos
+#undef good
+#undef cnt
+#undef tmp
+#undef val
+#undef f0
+#undef f1
+#undef mask
+#undef eta
+#undef bound
+#undef g0
+#undef g1
+
+/* simpasm: footer-start */
+#endif /* MLD_ARITH_BACKEND_X86_64_DEFAULT && !MLD_CONFIG_NO_KEYPAIR_API && \
+          !MLD_CONFIG_MULTILEVEL_NO_SHARED && \
+          (MLD_CONFIG_MULTILEVEL_WITH_SHARED || MLDSA_ETA == 4) */
diff --git a/proofs/hol_light/x86_64/Makefile b/proofs/hol_light/x86_64/Makefile
index ee9506698..6fb83d23e 100644
--- a/proofs/hol_light/x86_64/Makefile
+++ b/proofs/hol_light/x86_64/Makefile
@@ -40,9 +40,9 @@ OBJDUMP=$(CROSS_PREFIX)objdump -d
 # by single-quote characters in comments, so we eliminate // comments first.
 
 ifeq ($(OSTYPE_RESULT),Darwin)
-PREPROCESS=sed -e 's/\/\/.*//' | $(CC) -E -xassembler-with-cpp -
+PREPROCESS=sed -e 's/\/\/.*//' | $(CC) -E -xassembler-with-cpp -I$(BASE)/../../../mldsa -I$(BASE)/../../../mldsa/src -I$(BASE)/../../../common -DMLD_CONFIG_PARAMETER_SET=65 -
 else
-PREPROCESS=$(CC) -E -xassembler-with-cpp -
+PREPROCESS=$(CC) -E -xassembler-with-cpp -I$(BASE)/../../../mldsa -I$(BASE)/../../../mldsa/src -I$(BASE)/../../../common -DMLD_CONFIG_PARAMETER_SET=65 -
 endif
 
 # Generally GNU-type assemblers are happy with multiple instructions on
@@ -58,7 +58,9 @@ OBJ = mldsa/ntt_avx2_asm.o \
       mldsa/pointwise_acc_l4_avx2_asm.o \
       mldsa/pointwise_acc_l5_avx2_asm.o \
       mldsa/pointwise_acc_l7_avx2_asm.o \
-      mldsa/keccak_f1600_x4_avx2_asm.o
+      mldsa/keccak_f1600_x4_avx2_asm.o \
+      mldsa/rej_uniform_eta4_avx2_asm.o \
+      mldsa/rej_uniform_eta2_avx2_asm.o
 
 # Build object files from assembly sources
 $(OBJ): %.o : %.S
diff --git a/proofs/hol_light/x86_64/mldsa/rej_uniform_eta2_avx2_asm.S b/proofs/hol_light/x86_64/mldsa/rej_uniform_eta2_avx2_asm.S
new file mode 100644
index 000000000..704fb5d16
--- /dev/null
+++ b/proofs/hol_light/x86_64/mldsa/rej_uniform_eta2_avx2_asm.S
@@ -0,0 +1,173 @@
+/*
+ * Copyright (c) The mldsa-native project authors
+ * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT
+ */
+
+/* References
+ * ==========
+ *
+ * - [REF_AVX2]
+ *   CRYSTALS-Dilithium optimized AVX2 implementation
+ *   Bai, Ducas, Kiltz, Lepoint, Lyubashevsky, Schwabe, Seiler, Stehlé
+ *   https://github.com/pq-crystals/dilithium/tree/master/avx2
+ */
+
+/*
+ * This file is derived from the public domain
+ * AVX2 Dilithium implementation @[REF_AVX2].
+ */
+
+
+/*
+ * WARNING: This file is auto-derived from the mldsa-native source file
+ *   dev/x86_64/src/rej_uniform_eta2_avx2_asm.S using scripts/simpasm. Do not modify it directly.
+ */
+
+.text
+.balign 4
+#ifdef __APPLE__
+.global _PQCP_MLDSA_NATIVE_MLDSA44_rej_uniform_eta2_avx2_asm
+_PQCP_MLDSA_NATIVE_MLDSA44_rej_uniform_eta2_avx2_asm:
+#else
+.global PQCP_MLDSA_NATIVE_MLDSA44_rej_uniform_eta2_avx2_asm
+PQCP_MLDSA_NATIVE_MLDSA44_rej_uniform_eta2_avx2_asm:
+#endif
+
+        .cfi_startproc
+        endbr64
+        movl $0xf0f0f0f, %r8d # imm = 0xF0F0F0F
+        vmovd %r8d, %xmm3
+        vpbroadcastd %xmm3, %ymm3 # ymm3: 0x0f in every byte (nibble mask)
+        movl $0x2020202, %r8d # imm = 0x2020202
+        vmovd %r8d, %xmm4
+        vpbroadcastd %xmm4, %ymm4 # ymm4: ETA = 2 in every byte (byte-wise use only)
+        movl $0xf0f0f0f, %r8d # imm = 0xF0F0F0F
+        vmovd %r8d, %xmm5
+        vpbroadcastd %xmm5, %ymm5 # ymm5: rejection bound 15 in every byte
+        movl $0xffffe660, %r8d # imm = 0xFFFFE660
+        vpinsrw $0x0, %r8d, %xmm6, %xmm6
+        vpbroadcastw %xmm6, %ymm6 # ymm6: -6560 in every word (mod 5 multiplier)
+        movl $0x5, %r8d
+        vpinsrw $0x0, %r8d, %xmm7, %xmm7
+        vpbroadcastw %xmm7, %ymm7 # ymm7: 5 in every word
+        xorl %eax, %eax
+        xorl %ecx, %ecx
+
+Lrej_uniform_eta2_avx2_asm_loop:
+        cmpl $0xf8, %eax
+        ja Lrej_uniform_eta2_avx2_asm_scalar
+        cmpl $0x78, %ecx
+        ja Lrej_uniform_eta2_avx2_asm_scalar
+        vpmovzxbw (%rsi,%rcx), %ymm0 # ymm0 = mem[0],zero,mem[1],zero,mem[2],zero,mem[3],zero,mem[4],zero,mem[5],zero,mem[6],zero,mem[7],zero,mem[8],zero,mem[9],zero,mem[10],zero,mem[11],zero,mem[12],zero,mem[13],zero,mem[14],zero,mem[15],zero
+        vpsllw $0x4, %ymm0, %ymm1
+        vpor %ymm1, %ymm0, %ymm0
+        vpand %ymm3, %ymm0, %ymm0
+        vpsubb %ymm5, %ymm0, %ymm1 # nibble - 15: sign bit set iff nibble accepted
+        vpsubb %ymm0, %ymm4, %ymm0 # 2 - nibble as bytes, must precede the mod 5 step
+        vpmovmskb %ymm1, %r8d
+        vextracti128 $0x0, %ymm0, %xmm8
+        movzbl %r8b, %r10d
+        vmovq (%rdx,%r10,8), %xmm9
+        vpshufb %xmm9, %xmm8, %xmm9
+        vpmovsxbd %xmm9, %ymm1 # y = 2 - nibble, in [-12, 2] for accepted lanes
+        vpmulhrsw %ymm6, %ymm1, %ymm2 # q = round(-y / 5) in the low word, 0 in the high word
+        vpmullw %ymm7, %ymm2, %ymm2 # 5 * q
+        vpaddd %ymm2, %ymm1, %ymm1 # y + 5 * q = 2 - (nibble mod 5), in [-2, 2]
+        vmovdqu %ymm1, (%rdi,%rax,4)
+        popcntl %r10d, %r9d
+        addl %r9d, %eax
+        shrl $0x8, %r8d
+        addl $0x4, %ecx
+        cmpl $0xf8, %eax
+        ja Lrej_uniform_eta2_avx2_asm_scalar
+        vpsrldq $0x8, %xmm8, %xmm8 # xmm8 = xmm8[8,9,10,11,12,13,14,15],zero,zero,zero,zero,zero,zero,zero,zero
+        movzbl %r8b, %r10d
+        vmovq (%rdx,%r10,8), %xmm9
+        vpshufb %xmm9, %xmm8, %xmm9
+        vpmovsxbd %xmm9, %ymm1
+        vpmulhrsw %ymm6, %ymm1, %ymm2
+        vpmullw %ymm7, %ymm2, %ymm2
+        vpaddd %ymm2, %ymm1, %ymm1
+        vmovdqu %ymm1, (%rdi,%rax,4)
+        popcntl %r10d, %r9d
+        addl %r9d, %eax
+        shrl $0x8, %r8d
+        addl $0x4, %ecx
+        cmpl $0xf8, %eax
+        ja Lrej_uniform_eta2_avx2_asm_scalar
+        vextracti128 $0x1, %ymm0, %xmm8
+        movzbl %r8b, %r10d
+        vmovq (%rdx,%r10,8), %xmm9
+        vpshufb %xmm9, %xmm8, %xmm9
+        vpmovsxbd %xmm9, %ymm1
+        vpmulhrsw %ymm6, %ymm1, %ymm2
+        vpmullw %ymm7, %ymm2, %ymm2
+        vpaddd %ymm2, %ymm1, %ymm1
+        vmovdqu %ymm1, (%rdi,%rax,4)
+        popcntl %r10d, %r9d
+        addl %r9d, %eax
+        shrl $0x8, %r8d
+        addl $0x4, %ecx
+        cmpl $0xf8, %eax
+        ja Lrej_uniform_eta2_avx2_asm_scalar
+        vpsrldq $0x8, %xmm8, %xmm8 # xmm8 = xmm8[8,9,10,11,12,13,14,15],zero,zero,zero,zero,zero,zero,zero,zero
+        movzbl %r8b, %r10d
+        vmovq (%rdx,%r10,8), %xmm9
+        vpshufb %xmm9, %xmm8, %xmm9
+        vpmovsxbd %xmm9, %ymm1
+        vpmulhrsw %ymm6, %ymm1, %ymm2
+        vpmullw %ymm7, %ymm2, %ymm2
+        vpaddd %ymm2, %ymm1, %ymm1
+        vmovdqu %ymm1, (%rdi,%rax,4)
+        popcntl %r10d, %r9d
+        addl %r9d, %eax
+        addl $0x4, %ecx
+        jmp Lrej_uniform_eta2_avx2_asm_loop
+
+Lrej_uniform_eta2_avx2_asm_scalar:
+        cmpl $0x100, %eax # imm = 0x100
+        jae Lrej_uniform_eta2_avx2_asm_done
+        cmpl $0x88, %ecx
+        jae Lrej_uniform_eta2_avx2_asm_done
+        movzbl (%rsi,%rcx), %r11d
+        incl %ecx
+        movl %r11d, %r10d
+        andl $0xf, %r10d
+        cmpl $0xf, %r10d
+        jae Lrej_uniform_eta2_avx2_asm_high_nibble
+        movl %r10d, %r11d
+        imull $0xcd, %r11d, %r11d
+        shrl $0xa, %r11d
+        imull $0x5, %r11d, %r11d
+        subl %r11d, %r10d
+        movl $0x2, %r11d
+        subl %r10d, %r11d
+        movl %r11d, (%rdi,%rax,4)
+        incl %eax
+        cmpl $0x100, %eax # imm = 0x100
+        jae Lrej_uniform_eta2_avx2_asm_done
+
+Lrej_uniform_eta2_avx2_asm_high_nibble:
+        movzbl -0x1(%rsi,%rcx), %r11d # reload the byte: r11d was clobbered above
+        shrl $0x4, %r11d
+        andl $0xf, %r11d
+        cmpl $0xf, %r11d
+        jae Lrej_uniform_eta2_avx2_asm_scalar
+        movl %r11d, %r10d
+        imull $0xcd, %r10d, %r10d
+        shrl $0xa, %r10d
+        imull $0x5, %r10d, %r10d
+        subl %r10d, %r11d
+        movl $0x2, %r10d
+        subl %r11d, %r10d
+        movl %r10d, (%rdi,%rax,4)
+        incl %eax
+        jmp Lrej_uniform_eta2_avx2_asm_scalar
+
+Lrej_uniform_eta2_avx2_asm_done:
+        retq
+        .cfi_endproc
+
+#if defined(__ELF__)
+.section .note.GNU-stack,"",%progbits
+#endif
diff --git a/proofs/hol_light/x86_64/mldsa/rej_uniform_eta4_avx2_asm.S b/proofs/hol_light/x86_64/mldsa/rej_uniform_eta4_avx2_asm.S
new file mode 100644
index 000000000..4de978117
--- /dev/null
+++ b/proofs/hol_light/x86_64/mldsa/rej_uniform_eta4_avx2_asm.S
@@ -0,0 +1,226 @@
+/*
+ * Copyright (c) The mldsa-native project authors
+ * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT
+ */
+
+/* References
+ * ==========
+ *
+ * - [REF_AVX2]
+ *   CRYSTALS-Dilithium optimized AVX2 implementation
+ *   Bai, Ducas, Kiltz, Lepoint, Lyubashevsky, Schwabe, Seiler, Stehlé
+ *   https://github.com/pq-crystals/dilithium/tree/master/avx2
+ */
+
+/*
+ * This file is derived from the public domain
+ * AVX2 Dilithium implementation @[REF_AVX2].
+ */
+
+#include "../../../common.h"
+#if defined(MLD_ARITH_BACKEND_X86_64_DEFAULT) && !defined(MLD_CONFIG_MULTILEVEL_NO_SHARED)
+/* simpasm: header-end */
+
+#define out %rdi
+#define in %rsi
+#define tab %rdx
+
+#define ctr %eax
+#define pos %ecx
+
+#define good %r8d
+#define cnt %r9d
+#define tmp %r10
+#define val %r11d
+
+#define f0 %ymm0
+#define f1 %ymm1
+#define mask %ymm2
+#define eta %ymm3
+#define bound %ymm4
+#define g0 %xmm5
+#define g1 %xmm6
+
+        .text
+
+/*
+ * unsigned mld_rej_uniform_eta4_avx2_asm(int32_t *r, const uint8_t *buf,
+ *                                        const uint8_t *table)
+ *
+ * Rejection sampling for ETA=4 polynomial coefficients.
+ * Splits each input byte into two 4-bit nibbles and accepts those < 9.
+ * Output: coefficient = 4 - nibble, producing values in [-4, 4].
+ *
+ * Arguments: out (rdi): pointer to output coefficient array r (256 x int32_t)
+ *            in (rsi): pointer to input byte buffer buf (272 bytes)
+ *            tab (rdx): pointer to rejection sampling lookup table (256x8 bytes)
+ *
+ * Returns: ctr (eax): number of valid coefficients written to r
+ */
+        .balign 4
+        .global MLD_ASM_NAMESPACE(rej_uniform_eta4_avx2_asm)
+MLD_ASM_FN_SYMBOL(rej_uniform_eta4_avx2_asm)
+
+        endbr64
+
+// Construct broadcast constants from immediates (no .rodata loads)
+        movl $0x0F0F0F0F, good
+        vmovd good, %xmm2
+        vpbroadcastd %xmm2, mask  // mask: 0x0F in every byte (low 4 bits)
+
+        movl $0x04040404, good
+        vmovd good, %xmm3
+        vpbroadcastd %xmm3, eta  // eta: ETA=4 in every byte
+
+        movl $0x09090909, good
+        vmovd good, %xmm4
+        vpbroadcastd %xmm4, bound  // bound: rejection threshold 9 in every byte
+
+// Initialize counters
+        xorl ctr, ctr  // ctr = 0 (coefficients written so far)
+        xorl pos, pos  // pos = 0 (input bytes consumed so far)
+
+/*
+ * Main SIMD loop: process 16 input bytes into 32 nibbles, producing up to 32
+ * coefficients per iteration (processing 4 groups of 8 nibbles each).
+ *
+ * Loop-head guards: ctr <= MLDSA_N - 32 = 224 and pos <= BUFLEN - 16 = 256.
+ * Threshold 224 chosen so that worst-case post-iter ctr <= 224 + 32 = 256,
+ * fitting in the 256-int32 output buffer without overshoot.
+ *
+ * Each iter runs all 4 sub-iters unconditionally — no mid-iter early exits.
+ * This makes the proof structurally simpler (mirrors the AArch64 design)
+ * while saving 3 fused-uops/iter and 3 highly-correlated mispredicted
+ * branches at the loop tail.
+ */
+rej_uniform_eta4_avx2_asm_loop:
+        cmpl $224, ctr  // MLDSA_N - 32
+        ja rej_uniform_eta4_avx2_asm_scalar
+        cmpl $256, pos  // MLD_AVX2_REJ_UNIFORM_ETA4_BUFLEN - 16
+        ja rej_uniform_eta4_avx2_asm_scalar
+
+        // Load 16 bytes and extract 32 nibbles into f0
+        vpmovzxbw (in, %rcx), f0  // load 16 bytes at buf[pos], zero-extend to 16 words
+        vpsllw $4, f0, f1  // shift left by 4: high nibble moves into the upper byte
+        vpor f1, f0, f0  // word = b | (b << 4)
+        vpand mask, f0, f0  // per word: low byte = low nibble, high byte = high nibble
+
+        // Check which nibbles are < 9 and compute eta - nibble
+        vpsubb bound, f0, f1  // f1 = nibble - 9 (negative if valid)
+        vpsubb f0, eta, f0  // f0 = 4 - nibble
+        vpmovmskb f1, good  // extract sign bits (valid = 1), one bit per nibble
+
+        // Process first group of 8 nibbles (low 128 bits, low 64 bits of that)
+        vextracti128 $0, f0, g0  // extract low 128 bits
+        movzbl %r8b, %r10d  // tmp = good & 0xFF
+        vmovq (tab, tmp, 8), g1  // load shuffle indices from table
+        vpshufb g1, g0, g1  // compact valid nibbles to the front
+        vpmovsxbd g1, f1  // sign-extend bytes to dwords
+        vmovdqu f1, (out, %rax, 4)  // store 8 dwords at r[ctr]; only the first cnt are counted
+        popcntl %r10d, cnt  // popcount of low 8 bits (in r10d)
+        addl cnt, ctr  // ctr += popcount(low 8 bits)
+        shrl $8, good  // expose the next 8 mask bits
+        addl $4, pos  // consumed 4 input bytes
+
+        // Process second group of 8 nibbles (low 128 bits, high 64 bits)
+        vpsrldq $8, g0, g0  // shift right to get next 8 nibbles
+        movzbl %r8b, %r10d
+        vmovq (tab, tmp, 8), g1
+        vpshufb g1, g0, g1
+        vpmovsxbd g1, f1
+        vmovdqu f1, (out, %rax, 4)
+        popcntl %r10d, cnt  // popcount of low 8 bits (in r10d)
+        addl cnt, ctr
+        shrl $8, good
+        addl $4, pos
+
+        // Process third group of 8 nibbles (high 128 bits, low 64 bits)
+        vextracti128 $1, f0, g0  // extract high 128 bits
+        movzbl %r8b, %r10d
+        vmovq (tab, tmp, 8), g1
+        vpshufb g1, g0, g1
+        vpmovsxbd g1, f1
+        vmovdqu f1, (out, %rax, 4)
+        popcntl %r10d, cnt  // popcount of low 8 bits (in r10d)
+        addl cnt, ctr
+        shrl $8, good
+        addl $4, pos
+
+        // Process fourth group of 8 nibbles (high 128 bits, high 64 bits)
+        vpsrldq $8, g0, g0
+        movzbl %r8b, %r10d
+        vmovq (tab, tmp, 8), g1
+        vpshufb g1, g0, g1
+        vpmovsxbd g1, f1
+        vmovdqu f1, (out, %rax, 4)
+        popcntl %r10d, cnt  // popcount of low 8 bits (in r10d)
+        addl cnt, ctr
+        addl $4, pos
+
+        jmp rej_uniform_eta4_avx2_asm_loop
+
+/*
+ * Scalar fallback loop: process remaining bytes one nibble at a time.
+ * Each byte contains two 4-bit nibbles (low and high), each accepted iff < 9.
+ */
+rej_uniform_eta4_avx2_asm_scalar:
+        cmpl $256, ctr  // MLDSA_N
+        jae rej_uniform_eta4_avx2_asm_done
+        cmpl $272, pos  // MLD_AVX2_REJ_UNIFORM_ETA4_BUFLEN
+        jae rej_uniform_eta4_avx2_asm_done
+
+        movzbl (in, %rcx), val  // load buf[pos]
+        incl pos
+
+        // Process low nibble
+        movl val, %r10d
+        andl $0x0F, %r10d  // extract low 4 bits
+        cmpl $9, %r10d
+        jae rej_uniform_eta4_avx2_asm_high_nibble
+
+        movl $4, cnt  // cnt is scratch here; val must survive for the high nibble
+        subl %r10d, cnt  // 4 - nibble
+        movl cnt, (out, %rax, 4)
+        incl ctr
+
+        cmpl $256, ctr  // output full: drop the high nibble
+        jae rej_uniform_eta4_avx2_asm_done
+
+rej_uniform_eta4_avx2_asm_high_nibble:
+        // Process high nibble (val still holds the unmodified input byte)
+        shrl $4, val
+        andl $0x0F, val
+        cmpl $9, val
+        jae rej_uniform_eta4_avx2_asm_scalar
+
+        movl $4, %r10d
+        subl val, %r10d  // 4 - nibble
+        movl %r10d, (out, %rax, 4)
+        incl ctr
+
+        jmp rej_uniform_eta4_avx2_asm_scalar
+
+rej_uniform_eta4_avx2_asm_done:
+        ret
+
+/* To facilitate single-compilation-unit (SCU) builds, undefine all macros.
+ * Don't modify by hand -- this is auto-generated by scripts/autogen.
*/
+#undef out
+#undef in
+#undef tab
+#undef ctr
+#undef pos
+#undef good
+#undef cnt
+#undef tmp
+#undef val
+#undef f0
+#undef f1
+#undef mask
+#undef eta
+#undef bound
+#undef g0
+#undef g1
+
+/* simpasm: footer-start */
+#endif /* MLD_ARITH_BACKEND_X86_64_DEFAULT && !MLD_CONFIG_MULTILEVEL_NO_SHARED */
diff --git a/scripts/autogen b/scripts/autogen
index 3f03b1547..c92bf5383 100755
--- a/scripts/autogen
+++ b/scripts/autogen
@@ -2793,6 +2793,18 @@ def hol_light_asm_joblist():
             f"-Idev/fips202/x86_64/src -Imldsa/src/fips202/native/x86_64/src {x86_64_flags}",
             "x86_64",
         ),
+        (
+            "rej_uniform_eta4_avx2_asm.S",
+            "dev/x86_64/src",
+            f"-DMLD_ARITH_BACKEND_X86_64_DEFAULT -Imldsa/src/native/x86_64/src -Icommon {x86_64_flags}",
+            "x86_64",
+        ),
+        (
+            "rej_uniform_eta2_avx2_asm.S",
+            "dev/x86_64/src",
+            f"-DMLD_ARITH_BACKEND_X86_64_DEFAULT -Imldsa/src/native/x86_64/src -Icommon {x86_64_flags}",
+            "x86_64",
+        ),
     ]
 
     return joblist_aarch64 + joblist_x86_64
diff --git a/test/bench/bench_components_mldsa.c b/test/bench/bench_components_mldsa.c
index 7b7f2c62c..717ea635a 100644
--- a/test/bench/bench_components_mldsa.c
+++ b/test/bench/bench_components_mldsa.c
@@ -97,6 +97,16 @@ static int bench(void)
 
   BENCH("poly_caddq", mld_poly_caddq((mld_poly *)data0));
 
+  /* poly_uniform_eta_4x — exercises rej_uniform_eta{2,4}_avx2_asm on x86 */
+#if !defined(MLD_CONFIG_SERIAL_FIPS202_ONLY) && !defined(MLD_CONFIG_NO_KEYPAIR_API)
+  {
+    MLD_ALIGN mld_poly poly_eta0, poly_eta1, poly_eta2, poly_eta3;
+    BENCH("poly_uniform_eta_4x",
+          mld_poly_uniform_eta_4x(&poly_eta0, &poly_eta1, &poly_eta2, &poly_eta3,
+                                  (const uint8_t *)data0, 0, 1, 2, 3));
+  }
+#endif
+
   return 0;
 }
 
diff --git a/test/wycheproof/wycheproof_client.py b/test/wycheproof/wycheproof_client.py
index ad070473c..177f68f4a 100755
--- a/test/wycheproof/wycheproof_client.py
+++ b/test/wycheproof/wycheproof_client.py
@@ -18,8 +18,8 @@ exec_prefix = os.environ.get("EXEC_WRAPPER", "")
 exec_prefix = exec_prefix.split(" ") if exec_prefix != ""
else [] -# Pinned to a specific commit (2026-05-03). -WYCHEPROOF_COMMIT = "6d9d6de30f02e229dfc160323722c3ddac866181" +# Pinned to a specific commit (2026-06-04). +WYCHEPROOF_COMMIT = "4f5e05f71e6b724c20e2c1b6934c7bd7ef6d89e7" WYCHEPROOF_BASE_URL = f"https://raw.githubusercontent.com/C2SP/wycheproof/{WYCHEPROOF_COMMIT}/testvectors_v1" WYCHEPROOF_FILES = [