From ef5f658c66e3cff4b85202e17d22c39d54af1c9b Mon Sep 17 00:00:00 2001 From: Dimitri Sokolyuk Date: Sat, 8 Jul 2017 23:06:41 +0200 Subject: Correct round --- round.go | 38 ++++++++++++++++++++++++++++++++++---- round_test.go | 35 ++++++++++++++++++++++++++++++----- 2 files changed, 64 insertions(+), 9 deletions(-) diff --git a/round.go b/round.go index 345f7b0..1962ea6 100644 --- a/round.go +++ b/round.go @@ -6,9 +6,39 @@ import "math" // q = \sgn(y) \left\lfloor \left| y \right| + 0.5 \right\rfloor // = -\sgn(y) \left\lceil -\left| y \right| - 0.5 \right\rceil -// Round a float value to n decimal places -func Round(v float64, n int) float64 { +// RoundN a float value to n decimal places +func RoundN(x float64, n int) float64 { pow := math.Pow(10, float64(n)) - abs := math.Abs(v*pow) + 0.5 - return math.Copysign(math.Floor(abs)/pow, v) + return Round(x*pow) / pow +} + +// Round a float value +// https://www.cockroachlabs.com/blog/rounding-implementations-in-go/ +func Round(x float64) float64 { + const ( + mask = 0x7FF + shift = 64 - 11 - 1 + bias = 1023 + signMask = 1 << 63 + fracMask = (1 << shift) - 1 + halfMask = 1 << (shift - 1) + one = bias << shift + ) + + bits := math.Float64bits(x) + e := uint(bits>>shift) & mask + switch { + case e < bias: + // Round abs(x)<1 including denormals. + bits &= signMask // +-0 + if e == bias-1 { + bits |= one // +-1 + } + case e < bias+shift: + // Round any abs(x)>=1 containing a fractional component [0,1). + e -= bias + bits += halfMask >> e + bits &^= fracMask >> e + } + return math.Float64frombits(bits) } diff --git a/round_test.go b/round_test.go index 13a9051..6033156 100644 --- a/round_test.go +++ b/round_test.go @@ -2,10 +2,11 @@ package float import ( "fmt" + "math" "testing" ) -func TestRound(t *testing.T) { +func TestRound2(t *testing.T) { testCases := []struct { in, out float64 }{ @@ -24,15 +25,39 @@ func TestRound(t *testing.T) { } for _, tc := range testCases { t.Run(fmt.Sprint(tc.in), func(t *testing.T) { - if r := Round(tc.in, 2); r != tc.out { + if r := RoundN(tc.in, 2); r != tc.out { t.Errorf("got %v, want %v", r, tc.out) } }) } } -func BenchmarkRound(b *testing.B) { - for i := 0; i < b.N; i++ { - Round(0.333, 2) +func TestRound(t *testing.T) { + negZero := math.Copysign(0, -1) + testCases := []struct { + in, out float64 + }{ + {-0.49999999999999994, negZero}, // -0.5+epsilon + {-0.5, -1}, + {-0.5000000000000001, -1}, // -0.5-epsilon + {0, 0}, + {0.49999999999999994, 0}, // 0.5-epsilon + {0.5, 1}, + {0.5000000000000001, 1}, // 0.5+epsilon + {1.390671161567e-309, 0}, // denormal + {2.2517998136852485e+15, 2.251799813685249e+15}, // 1 bit fraction + {4.503599627370497e+15, 4.503599627370497e+15}, // large integer + {math.Inf(-1), math.Inf(-1)}, + {math.Inf(1), math.Inf(1)}, + {math.NaN(), math.NaN()}, + {negZero, negZero}, + } + for _, tc := range testCases { + t.Run(fmt.Sprint(tc.in, tc.out), func(t *testing.T) { + r := Round(tc.in) + if math.Float64bits(r) != math.Float64bits(tc.out) { + t.Errorf("got %v, want %v", r, tc.out) + } + }) } } -- cgit v1.2.3