diff --git a/hwy/contrib/math/fast_math-inl.h b/hwy/contrib/math/fast_math-inl.h index afe7d213db..47d94ba3ef 100644 --- a/hwy/contrib/math/fast_math-inl.h +++ b/hwy/contrib/math/fast_math-inl.h @@ -56,6 +56,84 @@ HWY_INLINE void ReduceAngleTan(D d, V ang, V& x_red, V& sign) { } // namespace impl +namespace impl { + +template +struct FastExpImpl {}; + +template <> +struct FastExpImpl { + // Rounds float toward zero and returns as int32_t. + template + HWY_INLINE Vec> ToInt32(D /*unused*/, V x) { + return ConvertInRangeTo(Rebind(), x); + } + + // Computes 2^x, where x is an integer. + template + HWY_INLINE Vec Pow2I(D d, VI32 x) { + const Rebind di32; + const VI32 kOffset = Set(di32, 0x7F); + return BitCast(d, ShiftLeft<23>(Add(x, kOffset))); + } + + // Returns x * 2^e via two multiplies: by 2^y and by 2^(e - y). + template + HWY_INLINE V LoadExpShortRange(D d, V x, VI32 e) { + const VI32 y = ShiftRight<1>(e); + return Mul(Mul(x, Pow2I(d, y)), Pow2I(d, Sub(e, y))); + } + + template + HWY_INLINE V ExpReduce(D d, V x, VI32 q) { + // kMinusLn2 ~= -ln(2) + const V kMinusLn2 = Set(d, -0.69314718056f); + + // Returns x - q * ln(2), computed with a single MulAdd. + const V qf = ConvertTo(d, q); + return MulAdd(qf, kMinusLn2, x); + } +}; + +#if HWY_HAVE_FLOAT64 && HWY_HAVE_INTEGER64 +template <> +struct FastExpImpl { + // Rounds double toward zero and returns as int32_t. + template + HWY_INLINE Vec> ToInt32(D /*unused*/, V x) { + return DemoteInRangeTo(Rebind(), x); + } + + // Computes 2^x, where x is an integer. + template + HWY_INLINE Vec Pow2I(D d, VI32 x) { + const Rebind di32; + const Rebind di64; + const VI32 kOffset = Set(di32, 0x3FF); + return BitCast(d, ShiftLeft<52>(PromoteTo(di64, Add(x, kOffset)))); + } + + // Returns x * 2^e via two multiplies: by 2^y and by 2^(e - y). 
+ template + HWY_INLINE V LoadExpShortRange(D d, V x, VI32 e) { + const VI32 y = ShiftRight<1>(e); + return Mul(Mul(x, Pow2I(d, y)), Pow2I(d, Sub(e, y))); + } + + template + HWY_INLINE V ExpReduce(D d, V x, VI32 q) { + // kMinusLn2 ~= -ln(2) + const V kMinusLn2 = Set(d, -0.6931471805599453); + + // Returns x - q * ln(2), computed with a single MulAdd. + const V qf = PromoteTo(d, q); + return MulAdd(qf, kMinusLn2, x); + } +}; +#endif + +} // namespace impl + /** * Fast approximation of tan(x). * @@ -778,6 +856,69 @@ HWY_INLINE V FastLog(D d, V x) { return MulAdd(exp, kLn2, approx); } +/** + * Fast approximation of exp(x). + * + * Valid Lane Types: float32, float64 + * Max ULP Error: 1 for float32 [-FLT_MAX, -87] + * Max ULP Error: 1 for float64 [-DBL_MAX, -708] + * Max Relative Error: 0.06% for float32 [-87, 88] + * Max Relative Error: 0.06% for float64 [-708, 706] + * Average Relative Error: 0.05% for float32 [-87, 88] + * Average Relative Error: 0.06% for float64 [-708, 706] + * Valid Range: float32[-FLT_MAX, +88], float64[-DBL_MAX, +706] + * + * @return e^x + */ +template +HWY_INLINE V FastExp(D d, V x) { + using T = TFromD; + impl::FastExpImpl impl; + + const V kHalf = Set(d, static_cast(+0.5)); + const V kLowerBound = + Set(d, static_cast((sizeof(T) == 4 ? 
-104.0 : -1000.0))); + const V kNegZero = Set(d, static_cast(-0.0)); + + const V kOneOverLog2 = Set(d, static_cast(+1.442695040888963407359924681)); + + using TI = MakeSigned; + const Rebind di; + const auto rounded_offs = BitCast( + d, OrAnd(BitCast(di, kHalf), BitCast(di, x), BitCast(di, kNegZero))); + + const auto q = impl.ToInt32(d, MulAdd(x, kOneOverLog2, rounded_offs)); + + // Range reduction: x_red = x - q * ln(2). + const auto x_red = impl.ExpReduce(d, x, q); + + // Rational approximation evaluated on |x_red|: + // x_in = |x_red| / 2 -> absorbed into coefficients + // if x_red < 0: swap num/den (yields the reciprocal) + + auto y = Abs(x_red); + + const auto a = Set(d, static_cast(-1757.05)); + const auto b = Set(d, static_cast(-3128.2)); + const auto c = Set(d, static_cast(1406.95)); + const auto d_coef = Set(d, static_cast(-3130.2)); + + // res = (Ay + B) / (Cy + D) + auto num = MulAdd(a, y, b); + auto den = MulAdd(c, y, d_coef); + + // If x_red < 0, swap num/den + auto final_num = IfNegativeThenElse(x_red, den, num); + auto final_den = IfNegativeThenElse(x_red, num, den); + + auto approx = Div(final_num, final_den); + + const V res = impl.LoadExpShortRange(d, approx, q); + + // Handle underflow: lanes failing x >= kLowerBound return zero. + return IfThenElseZero(Ge(x, kLowerBound), res); +} + template HWY_NOINLINE V CallFastAtan(const D d, VecArg x) { return FastAtan(d, x); @@ -803,6 +944,11 @@ HWY_NOINLINE V CallFastLog(const D d, VecArg x) { return FastLog(d, x); } +template +HWY_NOINLINE V CallFastExp(const D d, VecArg x) { + return FastExp(d, x); +} + } // namespace HWY_NAMESPACE } // namespace hwy HWY_AFTER_NAMESPACE(); diff --git a/hwy/contrib/math/math_test.cc b/hwy/contrib/math/math_test.cc index 32d7d6582e..2d30f8be11 100644 --- a/hwy/contrib/math/math_test.cc +++ b/hwy/contrib/math/math_test.cc @@ -222,6 +222,7 @@ HWY_NOINLINE void TestMathRelative(const char* name, T (*fx1)(T), } double max_actual_rel_error = 0.0; + double max_error_value = 0.0; // Emulation is slower, so cannot afford as many. 
const UintT kSamplesPerRange = static_cast(AdjustedReps(static_cast(samples))); @@ -248,7 +249,10 @@ HWY_NOINLINE void TestMathRelative(const char* name, T (*fx1)(T), double rel = std::abs(static_cast(actual) - static_cast(expected)) / std::abs(static_cast(expected)); - max_actual_rel_error = HWY_MAX(max_actual_rel_error, rel); + if (ScalarIsNaN(rel) || rel > max_actual_rel_error) { + max_actual_rel_error = rel; + max_error_value = static_cast(value); + } if (rel > max_relative_error) { static int print_count = 0; if (print_count < 10) { @@ -263,8 +267,9 @@ HWY_NOINLINE void TestMathRelative(const char* name, T (*fx1)(T), } } } - fprintf(stderr, "%s: %s max_rel_error %E\n", - hwy::TypeName(T(), Lanes(d)).c_str(), name, max_actual_rel_error); + fprintf(stderr, "%s: %s max_rel_error %E at %E\n", + hwy::TypeName(T(), Lanes(d)).c_str(), name, max_actual_rel_error, + max_error_value); HWY_ASSERT(max_actual_rel_error <= max_relative_error); } @@ -283,6 +288,40 @@ struct TestFastLog { } }; +struct TestFastExp { + template + HWY_NOINLINE void operator()(T, D d) { + if (sizeof(T) == 4) { + // Float Normal Range: [-87.0, +88.0] + // exp(-87) ~= 1.6e-38 (just above min normal 1.17e-38) + TestMathRelative("FastExpNormal", std::exp, CallFastExp, d, + static_cast(-87.0), static_cast(88.0), + 0.0007, 1e7); + + // Float Subnormal/Underflow Range: [-FLT_MAX, -87.0] + // exp(-104) is close to 0. Error is dominated by quantization (1 ULP ~= + // 50% relative error for small values). + TestMath("FastExpSubnormal", std::exp, CallFastExp, d, + static_cast(-FLT_MAX), static_cast(-87.0), 1); + } else { + // Double Normal Range: [-708.0, +706.0] + // exp(-708) ~= 3.3e-308 (just above min normal 2.22e-308) + TestMathRelative("FastExpNormal", std::exp, CallFastExp, d, + static_cast(-708.0), static_cast(706.0), + 0.0007, 1e7); + + // Double Subnormal/Underflow Range: [-DBL_MAX, -708.0] + // exp(-744) is very small. Quantization error is expected. 
+ TestMath("FastExpSubnormal", std::exp, CallFastExp, d, + static_cast(-DBL_MAX), static_cast(-708.0), 1); + } + } +}; + +HWY_NOINLINE void TestAllFastExp() { + ForFloat3264Types(ForPartialVectors()); +} + HWY_NOINLINE void TestAllFastLog() { ForFloat3264Types(ForPartialVectors()); } @@ -305,6 +344,7 @@ HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllLog10); HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllLog1p); HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllLog2); HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllFastLog); +HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllFastExp); HWY_AFTER_TEST(); } // namespace } // namespace hwy