From b3a73ceb53faae18b2f602aa0ed7612c3af1e87d Mon Sep 17 00:00:00 2001 From: Bram Veenboer Date: Fri, 1 Aug 2025 14:53:36 +0200 Subject: [PATCH] Add LookupAVXBackend --- benchmarks/CMakeLists.txt | 3 + benchmarks/benchmark_lookup_avx.cpp | 13 ++ include/trigdx/lookup_avx.hpp | 22 +++ src/CMakeLists.txt | 2 +- src/lookup_avx.cpp | 202 ++++++++++++++++++++++++++++ tests/CMakeLists.txt | 4 + tests/test_lookup_avx.cpp | 19 +++ 7 files changed, 264 insertions(+), 1 deletion(-) create mode 100644 benchmarks/benchmark_lookup_avx.cpp create mode 100644 include/trigdx/lookup_avx.hpp create mode 100644 src/lookup_avx.cpp create mode 100644 tests/test_lookup_avx.cpp diff --git a/benchmarks/CMakeLists.txt b/benchmarks/CMakeLists.txt index 228c610..df44091 100644 --- a/benchmarks/CMakeLists.txt +++ b/benchmarks/CMakeLists.txt @@ -4,6 +4,9 @@ target_link_libraries(benchmark_reference PRIVATE trigdx) add_executable(benchmark_lookup benchmark_lookup.cpp) target_link_libraries(benchmark_lookup PRIVATE trigdx) +add_executable(benchmark_lookup_avx benchmark_lookup_avx.cpp) +target_link_libraries(benchmark_lookup_avx PRIVATE trigdx) + if(USE_MKL) add_executable(benchmark_mkl benchmark_mkl.cpp) target_link_libraries(benchmark_mkl PRIVATE trigdx) diff --git a/benchmarks/benchmark_lookup_avx.cpp b/benchmarks/benchmark_lookup_avx.cpp new file mode 100644 index 0000000..ef1dcf9 --- /dev/null +++ b/benchmarks/benchmark_lookup_avx.cpp @@ -0,0 +1,13 @@ +#include + +#include "benchmark_utils.hpp" + +int main() { + benchmark_sinf>(); + benchmark_cosf>(); + benchmark_sincosf>(); + + benchmark_sinf>(); + benchmark_cosf>(); + benchmark_sincosf>(); +} diff --git a/include/trigdx/lookup_avx.hpp b/include/trigdx/lookup_avx.hpp new file mode 100644 index 0000000..228a7ab --- /dev/null +++ b/include/trigdx/lookup_avx.hpp @@ -0,0 +1,22 @@ +#pragma once + +#include +#include + +#include "interface.hpp" + +template class LookupAVXBackend : public Backend { +public: + LookupAVXBackend(); + 
~LookupAVXBackend() override; + + void init() override; + void compute_sinf(std::size_t n, const float *x, float *s) const override; + void compute_cosf(std::size_t n, const float *x, float *c) const override; + void compute_sincosf(std::size_t n, const float *x, float *s, + float *c) const override; + +private: + struct Impl; + std::unique_ptr impl; +}; diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 44e58a1..c371642 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -1,4 +1,4 @@ -add_library(trigdx reference.cpp lookup.cpp) +add_library(trigdx reference.cpp lookup.cpp lookup_avx.cpp) target_include_directories(trigdx PUBLIC ${PROJECT_SOURCE_DIR}/include) diff --git a/src/lookup_avx.cpp b/src/lookup_avx.cpp new file mode 100644 index 0000000..a3f8df2 --- /dev/null +++ b/src/lookup_avx.cpp @@ -0,0 +1,202 @@ +#include +#include +#include + +#include + +#include "trigdx/lookup_avx.hpp" + +template struct LookupAVXBackend::Impl { + std::vector lookup; + static constexpr std::size_t MASK = NR_SAMPLES - 1; + static constexpr float SCALE = NR_SAMPLES / (2.0f * float(M_PI)); + + void init() { + lookup.resize(NR_SAMPLES); + for (std::size_t i = 0; i < NR_SAMPLES; ++i) { + lookup[i] = std::sin(i * (2.0f * float(M_PI) / NR_SAMPLES)); + } + } + + void compute_sincosf(std::size_t n, const float *x, float *s, + float *c) const { +#if defined(__AVX2__) + constexpr std::size_t VL = 8; // AVX processes 8 floats + const __m256 scale = _mm256_set1_ps(SCALE); + const __m256i mask = _mm256_set1_epi32(MASK); + const __m256i quarter_pi = _mm256_set1_epi32(NR_SAMPLES / 4); + + std::size_t i = 0; + for (; i + VL <= n; i += VL) { + __m256 vx = _mm256_loadu_ps(&x[i]); + __m256 scaled = _mm256_mul_ps(vx, scale); + __m256i idx = _mm256_cvtps_epi32(scaled); + __m256i idx_cos = _mm256_add_epi32(idx, quarter_pi); + + idx = _mm256_and_si256(idx, mask); + idx_cos = _mm256_and_si256(idx_cos, mask); + +#if defined(__AVX2__) + __m256 sinv = _mm256_i32gather_ps(lookup.data(), idx, 
4); + __m256 cosv = _mm256_i32gather_ps(lookup.data(), idx_cos, 4); +#else + // fallback gather for AVX1 + float sin_tmp[VL], cos_tmp[VL]; + int idx_a[VL], idxc_a[VL]; + _mm256_storeu_si256((__m256i *)idx_a, idx); + _mm256_storeu_si256((__m256i *)idxc_a, idx_cos); + for (std::size_t k = 0; k < VL; ++k) { + sin_tmp[k] = lookup[idx_a[k]]; + cos_tmp[k] = lookup[idxc_a[k]]; + } + __m256 sinv = _mm256_loadu_ps(sin_tmp); + __m256 cosv = _mm256_loadu_ps(cos_tmp); +#endif + _mm256_storeu_ps(&s[i], sinv); + _mm256_storeu_ps(&c[i], cosv); + } + + // scalar remainder + for (; i < n; ++i) { + std::size_t idx = static_cast(x[i] * SCALE) & MASK; + std::size_t idx_cos = (idx + NR_SAMPLES / 4) & MASK; + s[i] = lookup[idx]; + c[i] = lookup[idx_cos]; + } +#else + // No AVX: scalar path + for (std::size_t i = 0; i < n; ++i) { + std::size_t idx = static_cast(x[i] * SCALE) & MASK; + std::size_t idx_cos = (idx + NR_SAMPLES / 4) & MASK; + s[i] = lookup[idx]; + c[i] = lookup[idx_cos]; + } +#endif + } + + void compute_sinf(std::size_t n, const float *x, float *s) const { +#if defined(__AVX2__) + constexpr std::size_t VL = 8; // AVX processes 8 floats + const __m256 scale = _mm256_set1_ps(SCALE); + const __m256i mask = _mm256_set1_epi32(MASK); + + std::size_t i = 0; + for (; i + VL <= n; i += VL) { + __m256 vx = _mm256_loadu_ps(&x[i]); + __m256 scaled = _mm256_mul_ps(vx, scale); + __m256i idx = _mm256_cvtps_epi32(scaled); + + idx = _mm256_and_si256(idx, mask); + +#if defined(__AVX2__) + __m256 sinv = _mm256_i32gather_ps(lookup.data(), idx, 4); +#else + // fallback gather for AVX1 + float sin_tmp[VL]; + int idx_a[VL]; + _mm256_storeu_si256((__m256i *)idx_a, idx); + for (std::size_t k = 0; k < VL; ++k) { + sin_tmp[k] = lookup[idx_a[k]]; + } + __m256 sinv = _mm256_loadu_ps(sin_tmp); +#endif + _mm256_storeu_ps(&s[i], sinv); + } + + // scalar remainder + for (; i < n; ++i) { + std::size_t idx = static_cast(x[i] * SCALE) & MASK; 
+ s[i] = lookup[idx]; + } +#else + // No AVX: scalar path + for (std::size_t i = 0; i < n; ++i) { + std::size_t idx = static_cast(x[i] * SCALE) & MASK; + s[i] = lookup[idx]; + } +#endif + } + + void compute_cosf(std::size_t n, const float *x, float *c) const { +#if defined(__AVX2__) + constexpr std::size_t VL = 8; // AVX processes 8 floats + const __m256 scale = _mm256_set1_ps(SCALE); + const __m256i mask = _mm256_set1_epi32(MASK); + const __m256i quarter_pi = _mm256_set1_epi32(NR_SAMPLES / 4); + + std::size_t i = 0; + for (; i + VL <= n; i += VL) { + __m256 vx = _mm256_loadu_ps(&x[i]); + __m256 scaled = _mm256_mul_ps(vx, scale); + __m256i idx = _mm256_cvtps_epi32(scaled); + __m256i idx_cos = _mm256_add_epi32(idx, quarter_pi); + + idx_cos = _mm256_and_si256(idx_cos, mask); + +#if defined(__AVX2__) + __m256 cosv = _mm256_i32gather_ps(lookup.data(), idx_cos, 4); +#else + // fallback gather for AVX1 + float cos_tmp[VL]; + int idxc_a[VL]; + _mm256_storeu_si256((__m256i *)idxc_a, idx_cos); + for (std::size_t k = 0; k < VL; ++k) { + cos_tmp[k] = lookup[idxc_a[k]]; + } + __m256 cosv = _mm256_loadu_ps(cos_tmp); +#endif + _mm256_storeu_ps(&c[i], cosv); + } + + // scalar remainder + for (; i < n; ++i) { + std::size_t idx = static_cast(x[i] * SCALE) & MASK; + std::size_t idx_cos = (idx + NR_SAMPLES / 4) & MASK; + c[i] = lookup[idx_cos]; + } +#else + // No AVX: scalar path + for (std::size_t i = 0; i < n; ++i) { + std::size_t idx = static_cast(x[i] * SCALE) & MASK; + std::size_t idx_cos = (idx + NR_SAMPLES / 4) & MASK; + c[i] = lookup[idx_cos]; + } +#endif + } +}; + +template +LookupAVXBackend::LookupAVXBackend() + : impl(std::make_unique()) {} + +template +LookupAVXBackend::~LookupAVXBackend() = default; + +template void LookupAVXBackend::init() { + impl->init(); +} + +template +void LookupAVXBackend::compute_sinf(std::size_t n, const float *x, + float *s) const { + impl->compute_sinf(n, x, s); +} + +template +void LookupAVXBackend::compute_cosf(std::size_t 
n, const float *x, + float *c) const { + impl->compute_cosf(n, x, c); +} + +template +void LookupAVXBackend::compute_sincosf(std::size_t n, + const float *x, float *s, + float *c) const { + impl->compute_sincosf(n, x, s, c); +} + +// Explicit instantiations +template class LookupAVXBackend<16384>; +template class LookupAVXBackend<32768>; diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index b6bf8b7..b333768 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -10,6 +10,10 @@ FetchContent_MakeAvailable(catch2) add_executable(test_lookup test_lookup.cpp) target_link_libraries(test_lookup PRIVATE trigdx Catch2::Catch2WithMain) +# LookupAVX backend test +add_executable(test_lookup_avx test_lookup_avx.cpp) +target_link_libraries(test_lookup_avx PRIVATE trigdx Catch2::Catch2WithMain) + # MKL backend test if(USE_MKL) add_executable(test_mkl test_mkl.cpp) diff --git a/tests/test_lookup_avx.cpp b/tests/test_lookup_avx.cpp new file mode 100644 index 0000000..1509526 --- /dev/null +++ b/tests/test_lookup_avx.cpp @@ -0,0 +1,19 @@ +#include +#include + +#include "test_utils.hpp" + +TEST_CASE("sinf") { + test_sinf>(1e-2f); + test_sinf>(1e-2f); +} + +TEST_CASE("cosf") { + test_cosf>(1e-2f); + test_cosf>(1e-2f); +} + +TEST_CASE("sincosf") { + test_sincosf>(1e-2f); + test_sincosf>(1e-2f); +} \ No newline at end of file