Vector Optimized Library of Kernels  3.3.0
Architecture-tuned implementations of math kernels
volk_avx512_intrinsics.h
Go to the documentation of this file.
1 /* -*- c++ -*- */
2 /*
3  * Copyright 2024-2026 Magnus Lundmark <magnuslundmark@gmail.com>
4  *
5  * This file is part of VOLK
6  *
7  * SPDX-License-Identifier: LGPL-3.0-or-later
8  */
9 
10 /*
11  * This file is intended to hold AVX512 intrinsics.
12  * They should be used in VOLK kernels to avoid copy-paste.
13  */
14 
15 #ifndef INCLUDE_VOLK_VOLK_AVX512_INTRINSICS_H_
16 #define INCLUDE_VOLK_VOLK_AVX512_INTRINSICS_H_
17 #include <immintrin.h>
18 
// Newton-Raphson refined reciprocal square root: 1/sqrt(a)
// _mm512_rsqrt14_ps returns an estimate with relative error below 2^-14.
// One Newton-Raphson step roughly squares that error, which brings the result
// close to full single precision:
//   x1 = x0 * (1.5 - 0.5 * a * x0^2)
// The term a * x0^2 (~= 1) is formed as (a * x0) * x0. Computing x0 * x0 first
// would overflow to +Inf for tiny subnormal a (where 1/a > FLT_MAX) and turn
// the result into -Inf.
// Edge cases (the hardware estimate is returned unchanged):
//   +0 -> +Inf, -0 -> -Inf, +Inf -> +0
// Requires AVX512F
static inline __m512 _mm512_rsqrt_nr_ps(const __m512 a)
{
    const __m512 HALF = _mm512_set1_ps(0.5f);
    const __m512 THREE_HALFS = _mm512_set1_ps(1.5f);
    const __m512i ABS_MASK = _mm512_set1_epi32(0x7FFFFFFF);
    const __m512i INF_BITS = _mm512_set1_epi32(0x7F800000);

    // Hardware estimate: signed Inf for +/-0, +0 for +Inf
    const __m512 x0 = _mm512_rsqrt14_ps(a);

    // Newton-Raphson step; multiply a by x0 first so the product stays near 1
    const __m512 a_x0_sq = _mm512_mul_ps(_mm512_mul_ps(a, x0), x0);
    const __m512 x1 = _mm512_mul_ps(x0, _mm512_fnmadd_ps(HALF, a_x0_sq, THREE_HALFS));

    // Where the estimate is +/-Inf or +/-0 it is already the exact answer, but
    // the NR step would evaluate Inf * 0 = NaN. Testing |x0| (rather than the
    // raw bits of a) also covers -0, whose estimate is -Inf.
    const __m512i x0_abs = _mm512_and_si512(_mm512_castps_si512(x0), ABS_MASK);
    const __mmask16 exact_mask =
        (__mmask16)(_mm512_cmpeq_epi32_mask(x0_abs, INF_BITS) |
                    _mm512_cmpeq_epi32_mask(x0_abs, _mm512_setzero_si512()));
    return _mm512_mask_blend_ps(exact_mask, x1, x0);
}
44 
// Extract the real parts of two vectors of interleaved complex floats.
// Each input holds 8 complex values laid out as [re0, im0, re1, im1, ...].
// Result lanes 0..7 are the real parts of z1; lanes 8..15 are those of z2.
// Requires AVX512F
static inline __m512 _mm512_real(const __m512 z1, const __m512 z2)
{
    // permutex2var indices 0..15 address z1 and 16..31 address z2;
    // selecting every even index keeps only the real components.
    const __m512i even_idx =
        _mm512_setr_epi32(0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30);
    return _mm512_permutex2var_ps(z1, even_idx, z2);
}
55 
// Extract the imaginary parts of two vectors of interleaved complex floats.
// Each input holds 8 complex values laid out as [re0, im0, re1, im1, ...].
// Result lanes 0..7 are the imaginary parts of z1; lanes 8..15 are those of z2.
// Requires AVX512F
static inline __m512 _mm512_imag(const __m512 z1, const __m512 z2)
{
    // permutex2var indices 0..15 address z1 and 16..31 address z2;
    // selecting every odd index keeps only the imaginary components.
    const __m512i odd_idx =
        _mm512_setr_epi32(1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31);
    return _mm512_permutex2var_ps(z1, odd_idx, z2);
}
66 
// Odd minimax polynomial approximating arctan(x) for x in [-1, 1]:
//   arctan(x) ~= x * (a1 + x^2*(a3 + x^2*(a5 + ... + x^2*a13)))
// Stated maximum relative error is about 6.5e-7.
// No range reduction is done here; inputs outside [-1, 1] are evaluated as-is.
// Evaluated by Horner's rule in x^2, one FMA per coefficient.
// Requires AVX512F
static inline __m512 _mm512_arctan_poly_avx512(const __m512 x)
{
    const __m512 xx = _mm512_mul_ps(x, x);

    __m512 p = _mm512_set1_ps(+0x1.01a37cp-7f);                   // a13
    p = _mm512_fmadd_ps(p, xx, _mm512_set1_ps(-0x1.2f3004p-5f)); // a11
    p = _mm512_fmadd_ps(p, xx, _mm512_set1_ps(+0x1.5785aap-4f)); // a9
    p = _mm512_fmadd_ps(p, xx, _mm512_set1_ps(-0x1.1436ap-3f));  // a7
    p = _mm512_fmadd_ps(p, xx, _mm512_set1_ps(+0x1.972be6p-3f)); // a5
    p = _mm512_fmadd_ps(p, xx, _mm512_set1_ps(-0x1.55437p-2f));  // a3
    p = _mm512_fmadd_ps(p, xx, _mm512_set1_ps(+0x1.ffffeap-1f)); // a1

    // Restore the odd symmetry: arctan(x) = x * P(x^2)
    return _mm512_mul_ps(x, p);
}
96 
// Polynomial approximation of arcsin(x) for |x| <= 0.5:
//   arcsin(x) ~= x * P(x^2),  P(u) = c0 + u*(c1 + u*(c2 + u*c3))
// Stated maximum relative error is about 1.5e-6.
// No range reduction is done here; inputs with |x| > 0.5 are evaluated as-is.
// Requires AVX512F
static inline __m512 _mm512_arcsin_poly_avx512(const __m512 x)
{
    const __m512 x_sq = _mm512_mul_ps(x, x);

    // Horner evaluation of P(x^2), highest degree first
    __m512 acc = _mm512_set1_ps(0x1.0a788p-4f);                      // c3
    acc = _mm512_fmadd_ps(acc, x_sq, _mm512_set1_ps(0x1.24d192p-4f)); // c2
    acc = _mm512_fmadd_ps(acc, x_sq, _mm512_set1_ps(0x1.55b648p-3f)); // c1
    acc = _mm512_fmadd_ps(acc, x_sq, _mm512_set1_ps(0x1.ffffcep-1f)); // c0

    return _mm512_mul_ps(x, acc);
}
119 
// Lane-wise complex product of two vectors, each holding 8 interleaved
// complex floats laid out as [re0, im0, re1, im1, ...]:
//   (a + bi) * (c + di) = (ac - bd) + i(ad + bc)
// Both partial products are rounded before the final add/subtract
// (no fused multiply-add is used).
// Requires AVX512F
static inline __m512 _mm512_complexmul_ps(const __m512 x, const __m512 y)
{
    const __m512 y_re = _mm512_moveldup_ps(y);        // c, c, ... per pair
    const __m512 y_im = _mm512_movehdup_ps(y);        // d, d, ... per pair
    const __m512 x_flip = _mm512_permute_ps(x, 0xB1); // b, a, ... per pair

    const __m512 direct = _mm512_mul_ps(x, y_re);     // ac, bc
    const __m512 cross = _mm512_mul_ps(x_flip, y_im); // bd, ad

    const __m512 diff = _mm512_sub_ps(direct, cross); // ac - bd (used in even lanes)
    const __m512 sum = _mm512_add_ps(direct, cross);  // bc + ad (used in odd lanes)

    // 0x5555 has the even (real) lanes set: take diff there, sum elsewhere
    return _mm512_mask_blend_ps(0x5555, sum, diff);
}
141 
// Lane-wise product of x with the complex conjugate of y, for vectors holding
// 8 interleaved complex floats laid out as [re0, im0, re1, im1, ...]:
//   (a + bi) * conj(c + di) = (a + bi) * (c - di) = (ac + bd) + i(bc - ad)
// The a*c / b*c products are rounded first; the remaining term is fused in
// with a single FMA.
// Requires AVX512F
static inline __m512 _mm512_complexconjugatemul_ps(const __m512 x, const __m512 y)
{
    const __m512 x_flip = _mm512_permute_ps(x, 0xb1); // b, a, ... per pair
    const __m512 y_re = _mm512_moveldup_ps(y);        // c, c, ... per pair
    const __m512i y_im_bits = _mm512_castps_si512(_mm512_movehdup_ps(y)); // d, d

    // -0.0f has only the sign bit set; xor it into the odd (imaginary) lanes
    // to conjugate: d, -d, ... per pair
    const __m512i sign_bit = _mm512_castps_si512(_mm512_set1_ps(-0.0f));
    const __m512 y_im_conj = _mm512_castsi512_ps(
        _mm512_mask_xor_epi32(y_im_bits, 0xAAAA, y_im_bits, sign_bit));

    // even lanes: b*d + a*c, odd lanes: a*(-d) + b*c
    const __m512 partial = _mm512_mul_ps(x, y_re);
    return _mm512_fmadd_ps(x_flip, y_im_conj, partial);
}
176 
// Scale each of 8 interleaved complex floats ([re0, im0, re1, im1, ...]) to
// unit magnitude: z / sqrt(re^2 + im^2).
// A zero input yields 0 / 0 = NaN in both components.
// Requires AVX512F
static inline __m512 _mm512_normalize_ps(const __m512 val)
{
    const __m512 squares = _mm512_mul_ps(val, val); // re^2, im^2, ...

    // Pair each square with its partner so both lanes of a complex value
    // hold the same |z|^2: re^2 + im^2, im^2 + re^2, ...
    const __m512 norm_sq = _mm512_add_ps(squares, _mm512_permute_ps(squares, 0xB1));

    return _mm512_div_ps(val, _mm512_sqrt_ps(norm_sq));
}
198 
// Odd minimax polynomial for sin(x) on [-pi/4, pi/4]
// (Remez coefficients from Sollya; stated max |error| < 7.3e-9):
//   sin(x) ~= x + x^3 * (s1 + x^2 * (s2 + x^2 * s3))
// No range reduction is done here.
// Requires AVX512F
static inline __m512 _mm512_sin_poly_avx512(const __m512 x)
{
    const __m512 xsq = _mm512_mul_ps(x, x);
    const __m512 xcube = _mm512_mul_ps(xsq, x);

    // Inner polynomial in x^2, highest degree first
    __m512 tail = _mm512_fmadd_ps(_mm512_set1_ps(-0x1.9ab22ap-13f), // s3
                                  xsq,
                                  _mm512_set1_ps(+0x1.110be2p-7f)); // s2
    tail = _mm512_fmadd_ps(tail, xsq, _mm512_set1_ps(-0x1.555552p-3f)); // s1

    // Final x + x^3 * tail with a single rounding
    return _mm512_fmadd_ps(tail, xcube, x);
}
219 
// Even minimax polynomial for cos(x) on [-pi/4, pi/4]
// (Remez coefficients from Sollya; stated max |error| < 1.1e-7):
//   cos(x) ~= 1 + x^2 * (c1 + x^2 * (c2 + x^2 * c3))
// No range reduction is done here.
// Requires AVX512F
static inline __m512 _mm512_cos_poly_avx512(const __m512 x)
{
    const __m512 xsq = _mm512_mul_ps(x, x);

    // Horner in x^2, highest degree first
    __m512 acc = _mm512_fmadd_ps(_mm512_set1_ps(-0x1.661be2p-10f), // c3
                                 xsq,
                                 _mm512_set1_ps(+0x1.554a46p-5f)); // c2
    acc = _mm512_fmadd_ps(acc, xsq, _mm512_set1_ps(-0x1.fffff4p-2f)); // c1

    // Constant term 1 folded into the last FMA
    return _mm512_fmadd_ps(acc, xsq, _mm512_set1_ps(1.0f));
}
240 
// Degree-6 minimax polynomial p(x) ~= log2(x) / (x - 1) for x in [1, 2].
// Generated with Sollya: remez(log2(x)/(x-1), 6, [1+1b-20, 2]); stated max
// error ~1.55e-6.
// Callers recover log2(x) ~= p(x) * (x - 1) for x in [1, 2]. This helper returns
// only p(x); it performs neither that multiply nor any range reduction.
// Evaluated by Horner's rule, one FMA per coefficient.
// Requires AVX512F
static inline __m512 _mm512_log2_poly_avx512(const __m512 x)
{
    __m512 acc = _mm512_set1_ps(+0x1.57aa42p-6f);                   // x^6 coefficient
    acc = _mm512_fmadd_ps(x, acc, _mm512_set1_ps(-0x1.c97982p-3f)); // x^5
    acc = _mm512_fmadd_ps(x, acc, _mm512_set1_ps(+0x1.04fc3ap+0f)); // x^4
    acc = _mm512_fmadd_ps(x, acc, _mm512_set1_ps(-0x1.4d476cp+1f)); // x^3
    acc = _mm512_fmadd_ps(x, acc, _mm512_set1_ps(+0x1.05d9ccp+2f)); // x^2
    acc = _mm512_fmadd_ps(x, acc, _mm512_set1_ps(-0x1.0b7f7ep+2f)); // x^1
    acc = _mm512_fmadd_ps(x, acc, _mm512_set1_ps(+0x1.a8a726p+1f)); // constant
    return acc;
}
270 
271 #endif /* INCLUDE_VOLK_VOLK_AVX512_INTRINSICS_H_ */
_mm512_complexmul_ps
static __m512 _mm512_complexmul_ps(const __m512 x, const __m512 y)
Definition: volk_avx512_intrinsics.h:124
volk_arch_defs.val
val
Definition: volk_arch_defs.py:57
_mm512_sin_poly_avx512
static __m512 _mm512_sin_poly_avx512(const __m512 x)
Definition: volk_avx512_intrinsics.h:206
_mm512_arcsin_poly_avx512
static __m512 _mm512_arcsin_poly_avx512(const __m512 x)
Definition: volk_avx512_intrinsics.h:104
_mm512_normalize_ps
static __m512 _mm512_normalize_ps(const __m512 val)
Definition: volk_avx512_intrinsics.h:181
_mm512_cos_poly_avx512
static __m512 _mm512_cos_poly_avx512(const __m512 x)
Definition: volk_avx512_intrinsics.h:227
_mm512_complexconjugatemul_ps
static __m512 _mm512_complexconjugatemul_ps(const __m512 x, const __m512 y)
Definition: volk_avx512_intrinsics.h:146
_mm512_arctan_poly_avx512
static __m512 _mm512_arctan_poly_avx512(const __m512 x)
Definition: volk_avx512_intrinsics.h:73
_mm512_imag
static __m512 _mm512_imag(const __m512 z1, const __m512 z2)
Definition: volk_avx512_intrinsics.h:60
_mm512_rsqrt_nr_ps
static __m512 _mm512_rsqrt_nr_ps(const __m512 a)
Definition: volk_avx512_intrinsics.h:26
_mm512_real
static __m512 _mm512_real(const __m512 z1, const __m512 z2)
Definition: volk_avx512_intrinsics.h:49
_mm512_log2_poly_avx512
static __m512 _mm512_log2_poly_avx512(const __m512 x)
Definition: volk_avx512_intrinsics.h:250