diff --git a/src/Benchmark/Program.cs b/src/Benchmark/Program.cs index 5af19a1f..5727066f 100644 --- a/src/Benchmark/Program.cs +++ b/src/Benchmark/Program.cs @@ -17,8 +17,8 @@ namespace Benchmark var config = ManualConfig.Create(DefaultConfig.Instance) .With(new MemoryDiagnoser()); //, new InliningDiagnoser()); - //BenchmarkRunner.Run(config); - BenchmarkRunner.Run(config); + BenchmarkRunner.Run(config); + //BenchmarkRunner.Run(config); //Benchmark(new LinearAlgebra.DenseVectorAdd(10000000,1), 10, "Large (10'000'000) - 10x1 iterations"); //Benchmark(new LinearAlgebra.DenseVectorAdd(100,1000), 100, "Small (100) - 100x1000 iterations"); diff --git a/src/NativeProviders/MKL/fft.cpp b/src/NativeProviders/MKL/fft.cpp index a16af563..072c4ca3 100644 --- a/src/NativeProviders/MKL/fft.cpp +++ b/src/NativeProviders/MKL/fft.cpp @@ -23,12 +23,9 @@ inline MKL_LONG fft_create_1d(DFTI_DESCRIPTOR_HANDLE* handle, const MKL_LONG n, } template -inline MKL_LONG fft_create_2d(DFTI_DESCRIPTOR_HANDLE* handle, const MKL_LONG m, const MKL_LONG n, const Precision forward_scale, const Precision backward_scale, const DFTI_CONFIG_VALUE precision, const DFTI_CONFIG_VALUE domain) +inline MKL_LONG fft_create_md(DFTI_DESCRIPTOR_HANDLE* handle, MKL_LONG dimensions, MKL_LONG n[], const Precision forward_scale, const Precision backward_scale, const DFTI_CONFIG_VALUE precision, const DFTI_CONFIG_VALUE domain) { - MKL_LONG sizes[2]; - sizes[0] = m; - sizes[1] = n; - MKL_LONG status = DftiCreateDescriptor(handle, precision, domain, 2, sizes); + MKL_LONG status = DftiCreateDescriptor(handle, precision, domain, dimensions, n); DFTI_DESCRIPTOR_HANDLE descriptor = *handle; if (0 == status) status = DftiSetValue(descriptor, DFTI_FORWARD_SCALE, forward_scale); if (0 == status) status = DftiSetValue(descriptor, DFTI_BACKWARD_SCALE, backward_scale); @@ -59,14 +56,14 @@ extern "C" { return fft_create_1d(handle, n, forward_scale, backward_scale, DFTI_SINGLE, DFTI_COMPLEX); } - DLLEXPORT MKL_LONG z_fft_create_2d(DFTI_DESCRIPTOR_HANDLE* handle, const MKL_LONG m, const MKL_LONG n, const double forward_scale, const double backward_scale) + DLLEXPORT MKL_LONG z_fft_create_multidim(DFTI_DESCRIPTOR_HANDLE* handle, MKL_LONG dimensions, MKL_LONG n[], const double forward_scale, const double backward_scale) { - return fft_create_2d(handle, m, n, forward_scale, backward_scale, DFTI_DOUBLE, DFTI_COMPLEX); + return fft_create_md(handle, dimensions, n, forward_scale, backward_scale, DFTI_DOUBLE, DFTI_COMPLEX); } - DLLEXPORT MKL_LONG c_fft_create_2d(DFTI_DESCRIPTOR_HANDLE* handle, const MKL_LONG m, const MKL_LONG n, const float forward_scale, const float backward_scale) + DLLEXPORT MKL_LONG c_fft_create_multidim(DFTI_DESCRIPTOR_HANDLE* handle, MKL_LONG dimensions, MKL_LONG n[], const float forward_scale, const float backward_scale) { - return fft_create_2d(handle, m, n, forward_scale, backward_scale, DFTI_SINGLE, DFTI_COMPLEX); + return fft_create_md(handle, dimensions, n, forward_scale, backward_scale, DFTI_SINGLE, DFTI_COMPLEX); } DLLEXPORT MKL_LONG z_fft_forward(const DFTI_DESCRIPTOR_HANDLE handle, MKL_Complex16 x[]) diff --git a/src/Numerics/Providers/Common/Mkl/SafeNativeMethods.cs b/src/Numerics/Providers/Common/Mkl/SafeNativeMethods.cs index 1aa93443..bab628ca 100644 --- a/src/Numerics/Providers/Common/Mkl/SafeNativeMethods.cs +++ b/src/Numerics/Providers/Common/Mkl/SafeNativeMethods.cs @@ -393,10 +393,10 @@ namespace MathNet.Numerics.Providers.Common.Mkl internal static extern int c_fft_create([Out] out IntPtr handle, int n, float forward_scale, float backward_scale); [DllImport(_DllName, ExactSpelling = true, SetLastError = false, CallingConvention = CallingConvention.Cdecl)] - internal static extern int z_fft_create_2d([Out] out IntPtr handle, int m, int n, double forward_scale, double backward_scale); + internal static extern int z_fft_create_multidim([Out] out IntPtr handle, int dimensions, [In] int[] n, double forward_scale, double backward_scale); [DllImport(_DllName, ExactSpelling = true, SetLastError = false, CallingConvention = CallingConvention.Cdecl)] - internal static extern int c_fft_create_2d([Out] out IntPtr handle, int m, int n, float forward_scale, float backward_scale); + internal static extern int c_fft_create_multidim([Out] out IntPtr handle, int dimensions, [In] int[] n, float forward_scale, float backward_scale); [DllImport(_DllName, ExactSpelling = true, SetLastError = false, CallingConvention = CallingConvention.Cdecl)] internal static extern int z_fft_forward([In] IntPtr handle, [In, Out] Complex[] x); diff --git a/src/Numerics/Providers/FourierTransform/IFourierTransformProvider.cs b/src/Numerics/Providers/FourierTransform/IFourierTransformProvider.cs index f5fea08e..1e3e5187 100644 --- a/src/Numerics/Providers/FourierTransform/IFourierTransformProvider.cs +++ b/src/Numerics/Providers/FourierTransform/IFourierTransformProvider.cs @@ -60,5 +60,8 @@ namespace MathNet.Numerics.Providers.FourierTransform Complex[] Forward(Complex[] complexTimeSpace, FourierTransformScaling scaling); Complex[] Backward(Complex[] complexFrequenceSpace, FourierTransformScaling scaling); + + void ForwardInplaceMultidim(Complex[] complex, int[] dimensions, FourierTransformScaling scaling); + void BackwardInplaceMultidim(Complex[] complex, int[] dimensions, FourierTransformScaling scaling); } } diff --git a/src/Numerics/Providers/FourierTransform/ManagedFourierTransformProvider.cs b/src/Numerics/Providers/FourierTransform/ManagedFourierTransformProvider.cs index b5c89c3e..ce3f69dc 100644 --- a/src/Numerics/Providers/FourierTransform/ManagedFourierTransformProvider.cs +++ b/src/Numerics/Providers/FourierTransform/ManagedFourierTransformProvider.cs @@ -26,12 +26,11 @@ // OTHER DEALINGS IN THE SOFTWARE. // -using System.Collections; +using System; using MathNet.Numerics.IntegralTransforms; namespace MathNet.Numerics.Providers.FourierTransform { - #if !NOSYSNUMERICS using Complex = System.Numerics.Complex; #endif @@ -107,5 +106,15 @@ namespace MathNet.Numerics.Providers.FourierTransform BackwardInplace(work, scaling); return work; } + + public void ForwardInplaceMultidim(Complex[] complex, int[] dimensions, FourierTransformScaling scaling) + { + throw new NotImplementedException(); + } + + public void BackwardInplaceMultidim(Complex[] complex, int[] dimensions, FourierTransformScaling scaling) + { + throw new NotImplementedException(); + } } } diff --git a/src/Numerics/Providers/FourierTransform/Mkl/MklFourierTransformProvider.cs b/src/Numerics/Providers/FourierTransform/Mkl/MklFourierTransformProvider.cs index f9ab7d94..fe21d002 100644 --- a/src/Numerics/Providers/FourierTransform/Mkl/MklFourierTransformProvider.cs +++ b/src/Numerics/Providers/FourierTransform/Mkl/MklFourierTransformProvider.cs @@ -40,7 +40,7 @@ namespace MathNet.Numerics.Providers.FourierTransform.Mkl class Kernel { public IntPtr Handle; - public int Length; + public int[] Dimensions; public FourierTransformScaling Scaling; } @@ -63,8 +63,8 @@ namespace MathNet.Numerics.Providers.FourierTransform.Mkl MklProvider.Load(minRevision: 11); // we only support exactly one major version, since major version changes imply a breaking change. - int fftMajor = SafeNativeMethods.query_capability((int)ProviderCapability.FourierTransformMajor); - int fftMinor = SafeNativeMethods.query_capability((int)ProviderCapability.FourierTransformMinor); + int fftMajor = SafeNativeMethods.query_capability((int) ProviderCapability.FourierTransformMajor); + int fftMinor = SafeNativeMethods.query_capability((int) ProviderCapability.FourierTransformMinor); if (!(fftMajor == 1 && fftMinor >= 0)) { throw new NotSupportedException(string.Format("MKL Native Provider not compatible. Expecting fourier transform v1 but provider implements v{0}.", fftMajor)); @@ -150,18 +150,78 @@ namespace MathNet.Numerics.Providers.FourierTransform.Mkl { kernel = new Kernel { - Length = length, + Dimensions = new[] {length}, Scaling = scaling }; + SafeNativeMethods.z_fft_create(out kernel.Handle, length, ForwardScaling(scaling, length), BackwardScaling(scaling, length)); + return kernel; } - if (kernel.Length != length || kernel.Scaling != scaling) + if (kernel.Dimensions.Length != 1 || kernel.Dimensions[0] != length || kernel.Scaling != scaling) { SafeNativeMethods.x_fft_free(ref kernel.Handle); SafeNativeMethods.z_fft_create(out kernel.Handle, length, ForwardScaling(scaling, length), BackwardScaling(scaling, length)); - kernel.Length = length; + kernel.Dimensions = new[] {length}; + kernel.Scaling = scaling; + return kernel; + } + + return kernel; + } + + Kernel Configure(int[] dimensions, FourierTransformScaling scaling) + { + if (dimensions.Length == 1) + { + return Configure(dimensions[0], scaling); + } + + Kernel kernel = Interlocked.Exchange(ref _kernel, null); + + if (kernel == null) + { + kernel = new Kernel + { + Dimensions = dimensions, + Scaling = scaling + }; + + long length = 1; + for (int i = 0; i < dimensions.Length; i++) + { + length *= dimensions[i]; + } + + SafeNativeMethods.z_fft_create_multidim(out kernel.Handle, dimensions.Length, dimensions, ForwardScaling(scaling, length), BackwardScaling(scaling, length)); + return kernel; + } + + bool mismatch = kernel.Dimensions.Length != dimensions.Length || kernel.Scaling != scaling; + if (!mismatch) + { + for (int i = 0; i < dimensions.Length; i++) + { + if (dimensions[i] != kernel.Dimensions[i]) + { + mismatch = true; + break; + } + } + } + + if (mismatch) + { + long length = 1; + for (int i = 0; i < dimensions.Length; i++) + { + length *= dimensions[i]; + } + + SafeNativeMethods.x_fft_free(ref kernel.Handle); + SafeNativeMethods.z_fft_create_multidim(out kernel.Handle, dimensions.Length, dimensions, ForwardScaling(scaling, length), BackwardScaling(scaling, length)); + kernel.Dimensions = dimensions; kernel.Scaling = scaling; return kernel; } @@ -208,7 +268,17 @@ namespace MathNet.Numerics.Providers.FourierTransform.Mkl return work; } - static double ForwardScaling(FourierTransformScaling scaling, int length) + public void ForwardInplaceMultidim(Complex[] complex, int[] dimensions, FourierTransformScaling scaling) + { + throw new NotImplementedException(); + } + + public void BackwardInplaceMultidim(Complex[] complex, int[] dimensions, FourierTransformScaling scaling) + { + throw new NotImplementedException(); + } + + static double ForwardScaling(FourierTransformScaling scaling, long length) { switch (scaling) { @@ -221,7 +291,7 @@ namespace MathNet.Numerics.Providers.FourierTransform.Mkl } } - static double BackwardScaling(FourierTransformScaling scaling, int length) + static double BackwardScaling(FourierTransformScaling scaling, long length) { switch (scaling) { @@ -241,4 +311,4 @@ namespace MathNet.Numerics.Providers.FourierTransform.Mkl } } -#endif +#endif \ No newline at end of file