Packing / unpacking FP32s as BFloat16s with AVX2 (truncating)

An FP32 value can be converted to BFloat16 simply by truncating its lower 16 bits, so the conversion can be done with a few bit operations.
Internally, the two sets of FP32s (i.e., the elements of two `__m256`s) are stored interleaved.
If your code is memory bound, packing FP32s this way may improve its performance.
I also expect this code to run faster than `_mm256_cvtph_ps`, which has a latency of 6–7 cycles depending on the platform. In contrast, `fp32_from_bf16_interleaved` should have a latency of 1 cycle.
Converting FP32 to BFloat16 is usually done as a preprocessing step, so its latency shouldn't be critical. But in case you care, it's 2 cycles for `bf16_interleaved_from_fp32`.
Code:

```cpp
#include <x86intrin.h>

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <utility>

static const auto kLower16BitsClear = _mm256_set1_epi32(0xffff0000);

// Using `std::pair<__m256, __m256>` triggers `-Wignored-attributes` and may
// cause incorrect alignment (not tested though).
struct EightFloatPair {
  __m256 x, y;
};

template <class T>
void debugging_dump_m256(const T& v) {
  for (int i = 0; i != 8; ++i) {
    printf("0x%08x ", reinterpret_cast<const std::uint32_t*>(&v)[i]);
  }
  printf("\n");
}

// Each 32-bit lane of the result holds x's upper 16 bits in its upper half
// and y's upper 16 bits in its lower half. The blend mask 0b01010101 takes
// the even 16-bit lanes from the shifted y and the odd ones from x.
__m256i bf16_interleaved_from_fp32(__m256 x, __m256 y) {
  return _mm256_blend_epi16(_mm256_castps_si256(x),
                            _mm256_srli_epi32(_mm256_castps_si256(y), 16),
                            0b01010101);
}

// Inverse of the above: x keeps its upper half (lower half zeroed), y's
// bits are shifted back into place.
EightFloatPair fp32_from_bf16_interleaved(__m256i v) {
  auto x = _mm256_and_si256(v, kLower16BitsClear);
  auto y = _mm256_slli_epi32(v, 16);
  return EightFloatPair{_mm256_castsi256_ps(x), _mm256_castsi256_ps(y)};
}

__m256i load_fp32s_as_bf16(const float* ptr) {
  auto p1 = _mm256_loadu_ps(&ptr[0]), p2 = _mm256_loadu_ps(&ptr[8]);
  return bf16_interleaved_from_fp32(p1, p2);
}

void store_bf16s_as_fp32(float* ptr, __m256i vs) {
  auto&& [x, y] = fp32_from_bf16_interleaved(vs);
  _mm256_storeu_ps(&ptr[0], x);
  _mm256_storeu_ps(&ptr[8], y);
}

float inputs[16], outputs[16];

int main() {
  for (int i = 0; i != 16; ++i) {
    inputs[i] = 1. / (i + 1);
  }
  store_bf16s_as_fp32(outputs, load_fp32s_as_bf16(inputs));
  for (int i = 0; i != 16; ++i) {
    auto was = inputs[i];
    auto now = outputs[i];
    auto diff = fabs(was - now);
    auto pct = diff / was * 100;
    printf("Element %d was %f, now it's %f. The difference is %f (%.1f%%).\n",
           i, was, now, diff, pct);
  }
}
```
Execution result:

```
Element 0 was 1.000000, now it's 1.000000. The difference is 0.000000 (0.0%).
Element 1 was 0.500000, now it's 0.500000. The difference is 0.000000 (0.0%).
Element 2 was 0.333333, now it's 0.332031. The difference is 0.001302 (0.4%).
Element 3 was 0.250000, now it's 0.250000. The difference is 0.000000 (0.0%).
Element 4 was 0.200000, now it's 0.199219. The difference is 0.000781 (0.4%).
Element 5 was 0.166667, now it's 0.166016. The difference is 0.000651 (0.4%).
Element 6 was 0.142857, now it's 0.142578. The difference is 0.000279 (0.2%).
Element 7 was 0.125000, now it's 0.125000. The difference is 0.000000 (0.0%).
Element 8 was 0.111111, now it's 0.110840. The difference is 0.000271 (0.2%).
Element 9 was 0.100000, now it's 0.099609. The difference is 0.000391 (0.4%).
Element 10 was 0.090909, now it's 0.090820. The difference is 0.000089 (0.1%).
Element 11 was 0.083333, now it's 0.083008. The difference is 0.000326 (0.4%).
Element 12 was 0.076923, now it's 0.076660. The difference is 0.000263 (0.3%).
Element 13 was 0.071429, now it's 0.071289. The difference is 0.000140 (0.2%).
Element 14 was 0.066667, now it's 0.066406. The difference is 0.000260 (0.4%).
Element 15 was 0.062500, now it's 0.062500. The difference is 0.000000 (0.0%).
```