Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 72 additions & 25 deletions compression/compress-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,16 @@ struct CompressTraits<float> {
const hn::Repartition<float, decltype(dbf16)> df;
using VF = hn::Vec<decltype(df)>;
const size_t NF = hn::Lanes(df);
const VF f0 = hn::LoadU(df, packed.ptr + packed_ofs + 0 * NF);
const VF f1 = hn::LoadU(df, packed.ptr + packed_ofs + 1 * NF);
const VF f2 = hn::LoadU(df, packed.ptr + packed_ofs + 2 * NF);
const VF f3 = hn::LoadU(df, packed.ptr + packed_ofs + 3 * NF);
const hn::Repartition<uint8_t, decltype(df)> du8;
const uint8_t* src_bytes = reinterpret_cast<const uint8_t*>(packed.ptr);
const VF f0 = hn::BitCast(
df, hn::LoadU(du8, src_bytes + (packed_ofs + 0 * NF) * sizeof(Packed)));
const VF f1 = hn::BitCast(
df, hn::LoadU(du8, src_bytes + (packed_ofs + 1 * NF) * sizeof(Packed)));
const VF f2 = hn::BitCast(
df, hn::LoadU(du8, src_bytes + (packed_ofs + 2 * NF) * sizeof(Packed)));
const VF f3 = hn::BitCast(
df, hn::LoadU(du8, src_bytes + (packed_ofs + 3 * NF) * sizeof(Packed)));
raw0 = hn::OrderedDemote2To(dbf16, f0, f1);
raw1 = hn::OrderedDemote2To(dbf16, f2, f3);
}
Expand All @@ -104,8 +110,12 @@ struct CompressTraits<float> {
static HWY_INLINE void Load2(DF df, const PackedSpan<const Packed>& packed,
                             const size_t packed_ofs, VF& raw0, VF& raw1) {
  // Loads two consecutive vectors of `Packed` elements into `raw0` and
  // `raw1`: `raw0` starts at element index `packed_ofs`, `raw1` starts
  // `hn::Lanes(df)` elements later.
  //
  // The source is read through a byte pointer and then bit-cast to the
  // vector type of `df`. Only this byte-wise pair of loads is kept: an
  // additional typed `hn::LoadU(df, packed.ptr + ...)` pair wrote to the same
  // outputs and was overwritten before being read, so it was redundant.
  // NOTE(review): presumably the byte-wise form exists so `packed.ptr` does
  // not have to satisfy the alignment of `Packed` - confirm with the author.
  const size_t N = hn::Lanes(df);
  const hn::Repartition<uint8_t, DF> du8;
  const uint8_t* src_bytes = reinterpret_cast<const uint8_t*>(packed.ptr);

  // `packed_ofs` and `N` count elements; scale by sizeof(Packed) to obtain
  // byte offsets into `src_bytes`.
  const size_t byte_ofs0 = packed_ofs * sizeof(Packed);
  const size_t byte_ofs1 = (packed_ofs + N) * sizeof(Packed);

  raw0 = hn::BitCast(df, hn::LoadU(du8, src_bytes + byte_ofs0));
  raw1 = hn::BitCast(df, hn::LoadU(du8, src_bytes + byte_ofs1));
}

template <class DD, HWY_IF_F64_D(DD), class VD = hn::Vec<DD>>
Expand All @@ -114,9 +124,12 @@ struct CompressTraits<float> {
const hn::Rebind<float, DD> df;
using VF = hn::Vec<decltype(df)>;
const size_t NF = hn::Lanes(df);
// Two half loads are likely cheaper than one full + UpperHalf.
const VF f0 = hn::LoadU(df, packed.ptr + packed_ofs + 0 * NF);
const VF f1 = hn::LoadU(df, packed.ptr + packed_ofs + 1 * NF);
const hn::Repartition<uint8_t, decltype(df)> du8;
const uint8_t* src_bytes = reinterpret_cast<const uint8_t*>(packed.ptr);
const VF f0 = hn::BitCast(
df, hn::LoadU(du8, src_bytes + (packed_ofs + 0 * NF) * sizeof(Packed)));
const VF f1 = hn::BitCast(
df, hn::LoadU(du8, src_bytes + (packed_ofs + 1 * NF) * sizeof(Packed)));
raw0 = hn::PromoteTo(dd, f0);
raw1 = hn::PromoteTo(dd, f1);
}
Expand All @@ -128,21 +141,31 @@ struct CompressTraits<float> {
const hn::Repartition<float, decltype(dbf)> df;
using VF = hn::Vec<decltype(df)>;
const size_t NF = hn::Lanes(df);
const hn::Repartition<uint8_t, decltype(dbf)> du8;
const uint8_t* src_bytes = reinterpret_cast<const uint8_t*>(packed.ptr);

size_t i = 0;
if (num >= 2 * NF) {
for (; i <= num - 2 * NF; i += 2 * NF) {
const VF f0 = hn::LoadU(df, packed.ptr + packed_ofs + i);
const VF f1 = hn::LoadU(df, packed.ptr + packed_ofs + i + NF);
const VF f0 = hn::BitCast(
df, hn::LoadU(du8, src_bytes + (packed_ofs + i) * sizeof(Packed)));
const VF f1 = hn::BitCast(
df,
hn::LoadU(du8, src_bytes + (packed_ofs + i + NF) * sizeof(Packed)));
hn::StoreU(hn::OrderedDemote2To(dbf, f0, f1), dbf, raw + i);
}
}
const size_t remaining = num - i;
HWY_DASSERT(remaining < 2 * NF);
if (HWY_UNLIKELY(remaining != 0)) {
const size_t remaining2 = remaining - HWY_MIN(remaining, NF);
const VF f0 = hn::LoadN(df, packed.ptr + packed_ofs + i, remaining);
const VF f1 = hn::LoadN(df, packed.ptr + packed_ofs + i + NF, remaining2);
const VF f0 = hn::BitCast(
df, hn::LoadN(du8, src_bytes + (packed_ofs + i) * sizeof(Packed),
HWY_MIN(remaining, NF) * sizeof(Packed)));
const VF f1 = hn::BitCast(
df, hn::LoadN(du8,
src_bytes + (packed_ofs + i + NF) * sizeof(Packed),
remaining2 * sizeof(Packed)));
hn::StoreU(hn::OrderedDemote2To(dbf, f0, f1), dbf, raw + i);
}
}
Expand All @@ -153,18 +176,23 @@ struct CompressTraits<float> {
float* HWY_RESTRICT raw, size_t num) {
using VF = hn::Vec<decltype(df)>;
const size_t NF = hn::Lanes(df);
const hn::Repartition<uint8_t, DF> du8;
const uint8_t* src_bytes = reinterpret_cast<const uint8_t*>(packed.ptr);

size_t i = 0;
if (num >= NF) {
for (; i <= num - NF; i += NF) {
const VF vf = hn::LoadU(df, packed.ptr + packed_ofs + i);
const VF vf = hn::BitCast(
df, hn::LoadU(du8, src_bytes + (packed_ofs + i) * sizeof(Packed)));
hn::StoreU(vf, df, raw + i);
}
}
const size_t remaining = num - i;
HWY_DASSERT(remaining < NF);
if (HWY_UNLIKELY(remaining != 0)) {
const VF vf = hn::LoadN(df, packed.ptr + packed_ofs + i, remaining);
const VF vf = hn::BitCast(
df, hn::LoadN(du8, src_bytes + (packed_ofs + i) * sizeof(Packed),
remaining * sizeof(Packed)));
hn::StoreU(vf, df, raw + i); // adds zero padding
}
}
Expand All @@ -176,18 +204,23 @@ struct CompressTraits<float> {
const hn::Rebind<float, DD> df;
using VF = hn::Vec<decltype(df)>;
const size_t ND = hn::Lanes(dd);
const hn::Repartition<uint8_t, decltype(df)> du8;
const uint8_t* src_bytes = reinterpret_cast<const uint8_t*>(packed.ptr);

size_t i = 0;
if (num >= ND) {
for (; i <= num - ND; i += ND) {
const VF vf = hn::LoadU(df, packed.ptr + packed_ofs + i);
const VF vf = hn::BitCast(
df, hn::LoadU(du8, src_bytes + (packed_ofs + i) * sizeof(Packed)));
hn::StoreU(hn::PromoteTo(dd, vf), dd, raw + i);
}
}
const size_t remaining = num - i;
HWY_DASSERT(remaining < ND);
if (HWY_UNLIKELY(remaining != 0)) {
const VF vf = hn::LoadN(df, packed.ptr + packed_ofs + i, remaining);
const VF vf = hn::BitCast(
df, hn::LoadN(du8, src_bytes + (packed_ofs + i) * sizeof(Packed),
remaining * sizeof(Packed)));
hn::StoreU(hn::PromoteTo(dd, vf), dd, raw + i); // adds zero padding
}
}
Expand Down Expand Up @@ -265,9 +298,13 @@ struct CompressTraits<BF16> {
const PackedSpan<const Packed>& packed,
const size_t packed_ofs, hn::Vec<DBF16>& raw0,
hn::Vec<DBF16>& raw1) {
const hn::Repartition<uint8_t, DBF16> du8;
const size_t N16 = hn::Lanes(dbf16);
raw0 = hn::LoadU(dbf16, packed.ptr + packed_ofs);
raw1 = hn::LoadU(dbf16, packed.ptr + packed_ofs + N16);
const uint8_t* src_bytes = reinterpret_cast<const uint8_t*>(packed.ptr);
raw0 = hn::BitCast(dbf16,
hn::LoadU(du8, src_bytes + packed_ofs * sizeof(Packed)));
raw1 = hn::BitCast(
dbf16, hn::LoadU(du8, src_bytes + (packed_ofs + N16) * sizeof(Packed)));
}

template <class DF, HWY_IF_F32_D(DF)>
Expand All @@ -276,7 +313,10 @@ struct CompressTraits<BF16> {
hn::Vec<DF>& raw1) {
const hn::Repartition<BF16, decltype(df)> dbf;
using VBF = hn::Vec<decltype(dbf)>;
const VBF packed0 = hn::LoadU(dbf, packed.ptr + packed_ofs);
const hn::Repartition<uint8_t, decltype(df)> du8;
const uint8_t* src_bytes = reinterpret_cast<const uint8_t*>(packed.ptr);
const VBF packed0 = hn::BitCast(
dbf, hn::LoadU(du8, src_bytes + packed_ofs * sizeof(Packed)));
raw0 = hn::PromoteLowerTo(df, packed0);
raw1 = hn::PromoteUpperTo(df, packed0);
}
Expand All @@ -287,20 +327,24 @@ struct CompressTraits<BF16> {
BF16* HWY_RESTRICT raw, size_t num) {
using VBF = hn::Vec<decltype(dbf)>;
const size_t N16 = hn::Lanes(dbf);
const hn::Repartition<uint8_t, DBF> du8;
const uint8_t* src_bytes = reinterpret_cast<const uint8_t*>(packed.ptr);

size_t i = 0;
if (num >= N16) {
for (; i <= num - N16; i += N16) {
const VBF packed0 = hn::LoadU(dbf, packed.ptr + packed_ofs + i);
const VBF packed0 = hn::BitCast(
dbf, hn::LoadU(du8, src_bytes + (packed_ofs + i) * sizeof(Packed)));
hn::StoreU(packed0, dbf, raw + i);
}
}

const size_t remaining = num - i;
HWY_DASSERT(remaining < N16);
if (HWY_UNLIKELY(remaining != 0)) {
const VBF packed0 =
hn::LoadN(dbf, packed.ptr + packed_ofs + i, remaining);
const VBF packed0 = hn::BitCast(
dbf, hn::LoadN(du8, src_bytes + (packed_ofs + i) * sizeof(Packed),
remaining * sizeof(Packed)));
hn::StoreU(packed0, dbf, raw + i);
}
}
Expand Down Expand Up @@ -363,8 +407,11 @@ struct CompressTraits<BF16> {
const size_t remaining = num - i;
HWY_DASSERT(remaining < 2 * NF);
if (HWY_UNLIKELY(remaining != 0)) {
const VBF packed0 =
hn::LoadN(dbf, packed.ptr + packed_ofs + i, remaining);
const hn::Repartition<uint8_t, decltype(dbf)> du8;
const uint8_t* src_bytes = reinterpret_cast<const uint8_t*>(packed.ptr);
const VBF packed0 = hn::BitCast(
dbf, hn::LoadN(du8, src_bytes + (packed_ofs + i) * sizeof(Packed),
remaining * sizeof(Packed)));
const VF raw0 = hn::PromoteLowerTo(df, packed0);
const VF raw1 = hn::PromoteUpperTo(df, packed0);
// If at most one vector, the first store adds zero padding. Check before
Expand Down
Loading