summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--round.go38
-rw-r--r--round_test.go35
2 files changed, 64 insertions, 9 deletions
diff --git a/round.go b/round.go
index 345f7b0..1962ea6 100644
--- a/round.go
+++ b/round.go
@@ -6,9 +6,39 @@ import "math"
// q = \sgn(y) \left\lfloor \left| y \right| + 0.5 \right\rfloor
// = -\sgn(y) \left\lceil -\left| y \right| - 0.5 \right\rceil
-// Round a float value to n decimal places
-func Round(v float64, n int) float64 {
+// RoundN a float value to n decimal places
+func RoundN(x float64, n int) float64 {
pow := math.Pow(10, float64(n))
- abs := math.Abs(v*pow) + 0.5
- return math.Copysign(math.Floor(abs)/pow, v)
+ return Round(x*pow) / pow
+}
+
+// Round a float value
+// https://www.cockroachlabs.com/blog/rounding-implementations-in-go/
+func Round(x float64) float64 {
+ const (
+ mask = 0x7FF
+ shift = 64 - 11 - 1
+ bias = 1023
+ signMask = 1 << 63
+ fracMask = (1 << shift) - 1
+ halfMask = 1 << (shift - 1)
+ one = bias << shift
+ )
+
+ bits := math.Float64bits(x)
+ e := uint(bits>>shift) & mask
+ switch {
+ case e < bias:
+ // Round abs(x)<1 including denormals.
+ bits &= signMask // +-0
+ if e == bias-1 {
+ bits |= one // +-1
+ }
+ case e < bias+shift:
+ // Round any abs(x)>=1 containing a fractional component [0,1).
+ e -= bias
+ bits += halfMask >> e
+ bits &^= fracMask >> e
+ }
+ return math.Float64frombits(bits)
}
diff --git a/round_test.go b/round_test.go
index 13a9051..6033156 100644
--- a/round_test.go
+++ b/round_test.go
@@ -2,10 +2,11 @@ package float
import (
"fmt"
+ "math"
"testing"
)
-func TestRound(t *testing.T) {
+func TestRound2(t *testing.T) {
testCases := []struct {
in, out float64
}{
@@ -24,15 +25,39 @@ func TestRound(t *testing.T) {
}
for _, tc := range testCases {
t.Run(fmt.Sprint(tc.in), func(t *testing.T) {
- if r := Round(tc.in, 2); r != tc.out {
+ if r := RoundN(tc.in, 2); r != tc.out {
t.Errorf("got %v, want %v", r, tc.out)
}
})
}
}
-func BenchmarkRound(b *testing.B) {
- for i := 0; i < b.N; i++ {
- Round(0.333, 2)
+func TestRound(t *testing.T) {
+ negZero := math.Copysign(0, -1)
+ testCases := []struct {
+ in, out float64
+ }{
+ {-0.49999999999999994, negZero}, // -0.5+epsilon
+ {-0.5, -1},
+ {-0.5000000000000001, -1}, // -0.5-epsilon
+ {0, 0},
+ {0.49999999999999994, 0}, // 0.5-epsilon
+ {0.5, 1},
+ {0.5000000000000001, 1}, // 0.5+epsilon
+ {1.390671161567e-309, 0}, // denormal
+ {2.2517998136852485e+15, 2.251799813685249e+15}, // 1 bit fraction
+ {4.503599627370497e+15, 4.503599627370497e+15}, // large integer
+ {math.Inf(-1), math.Inf(-1)},
+ {math.Inf(1), math.Inf(1)},
+ {math.NaN(), math.NaN()},
+ {negZero, negZero},
+ }
+ for _, tc := range testCases {
+ t.Run(fmt.Sprint(tc.in, tc.out), func(t *testing.T) {
+ r := Round(tc.in)
+ if math.Float64bits(r) != math.Float64bits(tc.out) {
+ t.Errorf("got %v, want %v", r, tc.out)
+ }
+ })
}
}