Merge pull request #10 from astron-rd/add-python-interface
Add Python interface
This commit is contained in:
@@ -3,18 +3,22 @@ project(trigdx LANGUAGES CXX)
|
|||||||
|
|
||||||
set(CMAKE_CXX_STANDARD 17)
|
set(CMAKE_CXX_STANDARD 17)
|
||||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||||
|
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
|
||||||
|
|
||||||
option(TRIGDX_USE_MKL "Enable Intel MKL backend" OFF)
|
option(TRIGDX_USE_MKL "Enable Intel MKL backend" OFF)
|
||||||
option(TRIGDX_USE_GPU "Enable GPU backend" OFF)
|
option(TRIGDX_USE_GPU "Enable GPU backend" OFF)
|
||||||
option(TRIGDX_USE_XSIMD "Enable XSIMD backend" OFF)
|
option(TRIGDX_USE_XSIMD "Enable XSIMD backend" OFF)
|
||||||
option(TRIGDX_BUILD_TESTS "Build tests" ON)
|
option(TRIGDX_BUILD_TESTS "Build tests" ON)
|
||||||
option(TRIGDX_BUILD_BENCHMARKS "Build tests" ON)
|
option(TRIGDX_BUILD_BENCHMARKS "Build tests" ON)
|
||||||
|
option(TRIGDX_BUILD_PYTHON "Build Python interface" ON)
|
||||||
|
|
||||||
configure_file(
|
configure_file(
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/cmake/trigdx_config.hpp.in
|
${CMAKE_CURRENT_SOURCE_DIR}/cmake/trigdx_config.hpp.in
|
||||||
${CMAKE_CURRENT_BINARY_DIR}/include/trigdx/trigdx_config.hpp @ONLY)
|
${CMAKE_CURRENT_BINARY_DIR}/include/trigdx/trigdx_config.hpp @ONLY)
|
||||||
|
|
||||||
if(TRIGDX_BUILD_TESTS OR TRIGDX_BUILD_BENCHMARKS)
|
if(TRIGDX_BUILD_TESTS
|
||||||
|
OR TRIGDX_BUILD_BENCHMARKS
|
||||||
|
OR TRIGDX_BUILD_PYTHON)
|
||||||
include(FetchContent)
|
include(FetchContent)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
@@ -31,3 +35,7 @@ endif()
|
|||||||
if(TRIGDX_BUILD_BENCHMARKS)
|
if(TRIGDX_BUILD_BENCHMARKS)
|
||||||
add_subdirectory(benchmarks)
|
add_subdirectory(benchmarks)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
if(TRIGDX_BUILD_PYTHON)
|
||||||
|
add_subdirectory(python)
|
||||||
|
endif()
|
||||||
|
|||||||
8
python/CMakeLists.txt
Normal file
8
python/CMakeLists.txt
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
FetchContent_Declare(
|
||||||
|
pybind11
|
||||||
|
GIT_REPOSITORY https://github.com/pybind/pybind11.git
|
||||||
|
GIT_TAG v3.0.0)
|
||||||
|
FetchContent_MakeAvailable(pybind11)
|
||||||
|
|
||||||
|
pybind11_add_module(pytrigdx bindings.cpp)
|
||||||
|
target_link_libraries(pytrigdx PRIVATE trigdx)
|
||||||
94
python/bindings.cpp
Normal file
94
python/bindings.cpp
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include <pybind11/numpy.h>
|
||||||
|
#include <pybind11/pybind11.h>
|
||||||
|
#include <trigdx/trigdx.hpp>
|
||||||
|
|
||||||
|
namespace py = pybind11;
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
py::array_t<T>
|
||||||
|
compute_sin(const Backend &backend,
|
||||||
|
py::array_t<T, py::array::c_style | py::array::forcecast> x) {
|
||||||
|
const size_t n = x.shape(0);
|
||||||
|
if (x.ndim() != 1) {
|
||||||
|
throw py::value_error("Input array must be 1-dimensional");
|
||||||
|
}
|
||||||
|
|
||||||
|
const T *x_ptr = x.data();
|
||||||
|
|
||||||
|
py::array_t<float> s(n);
|
||||||
|
T *s_ptr = s.mutable_data();
|
||||||
|
|
||||||
|
backend.compute_sinf(n, x_ptr, s_ptr);
|
||||||
|
|
||||||
|
return s;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
py::array_t<T>
|
||||||
|
compute_cos(const Backend &backend,
|
||||||
|
py::array_t<T, py::array::c_style | py::array::forcecast> x) {
|
||||||
|
const size_t n = x.shape(0);
|
||||||
|
if (x.ndim() != 1) {
|
||||||
|
throw py::value_error("Input array must be 1-dimensional");
|
||||||
|
}
|
||||||
|
|
||||||
|
const T *x_ptr = x.data();
|
||||||
|
|
||||||
|
py::array_t<T> c(n);
|
||||||
|
T *c_ptr = c.mutable_data();
|
||||||
|
|
||||||
|
backend.compute_cosf(n, x_ptr, c_ptr);
|
||||||
|
|
||||||
|
return c;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
std::tuple<py::array_t<T>, py::array_t<T>>
|
||||||
|
compute_sincos(const Backend &backend,
|
||||||
|
py::array_t<T, py::array::c_style | py::array::forcecast> x) {
|
||||||
|
const size_t n = x.shape(0);
|
||||||
|
if (x.ndim() != 1) {
|
||||||
|
throw py::value_error("Input array must be 1-dimensional");
|
||||||
|
}
|
||||||
|
|
||||||
|
const T *x_ptr = x.data();
|
||||||
|
|
||||||
|
py::array_t<T> s(n);
|
||||||
|
py::array_t<T> c(n);
|
||||||
|
|
||||||
|
backend.compute_sincosf(n, x_ptr, s.mutable_data(), c.mutable_data());
|
||||||
|
|
||||||
|
return std::make_tuple(s, c);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename BackendType>
|
||||||
|
void bind_backend(py::module &m, const char *name) {
|
||||||
|
py::class_<BackendType, Backend, std::shared_ptr<BackendType>>(m, name)
|
||||||
|
.def(py::init<>())
|
||||||
|
.def("compute_sinf", &compute_sin<float>)
|
||||||
|
.def("compute_cosf", &compute_cos<float>)
|
||||||
|
.def("compute_sincosf", &compute_sincos<float>);
|
||||||
|
}
|
||||||
|
|
||||||
|
PYBIND11_MODULE(pytrigdx, m) {
|
||||||
|
py::class_<Backend, std::shared_ptr<Backend>>(m, "Backend")
|
||||||
|
.def("init", &Backend::init);
|
||||||
|
|
||||||
|
bind_backend<ReferenceBackend>(m, "Reference");
|
||||||
|
bind_backend<LookupBackend<16384>>(m, "Lookup16K");
|
||||||
|
bind_backend<LookupBackend<32768>>(m, "Lookup32K");
|
||||||
|
bind_backend<LookupAVXBackend<16384>>(m, "LookupAVX16K");
|
||||||
|
bind_backend<LookupAVXBackend<32768>>(m, "LookupAVX32K");
|
||||||
|
#if defined(TRIGDX_USE_MKL)
|
||||||
|
bind_backend<MKLBackend>(m, "MKL");
|
||||||
|
#endif
|
||||||
|
#if defined(TRIGDX_USE_GPU)
|
||||||
|
bind_backend<GPUBackend>(m, "GPU");
|
||||||
|
#endif
|
||||||
|
#if defined(TRIGDX_USE_XSIMD)
|
||||||
|
bind_backend<LookupXSIMDBackend<16384>>(m, "LookupXSIMD16K");
|
||||||
|
bind_backend<LookupXSIMDBackend<32768>>(m, "LookupXSIMD32K");
|
||||||
|
#endif
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user