aboutsummaryrefslogtreecommitdiff
path: root/bencode
diff options
context:
space:
mode:
authorDimitri Sokolyuk <demon@dim13.org>2016-06-12 06:47:50 +0200
committerDimitri Sokolyuk <demon@dim13.org>2016-06-12 06:47:50 +0200
commitf5b44cf1811c768dc268f6097422bb50133ce913 (patch)
tree7a79d2634cf8db8ffc09bac503cdce0d4032fe29 /bencode
parent1a197f171413d85b04b2d48853acb594b345dbc5 (diff)
Add special case for info_hash
Diffstat (limited to 'bencode')
-rw-r--r--bencode/bencode.go23
-rw-r--r--bencode/bencode_test.go35
2 files changed, 43 insertions, 15 deletions
diff --git a/bencode/bencode.go b/bencode/bencode.go
index 01467f9..f4f6a85 100644
--- a/bencode/bencode.go
+++ b/bencode/bencode.go
@@ -2,6 +2,7 @@ package bencode
import (
"bytes"
+ "crypto/sha1"
"errors"
"fmt"
"io"
@@ -106,17 +107,21 @@ func parseTag(tag string) (string, string) {
}
type decodeState struct {
- data []byte
- off int
+ data []byte
+ off int
+ infoOff int
+ infoEnd int
}
-func Unmarshal(data []byte, v interface{}) error {
+func Unmarshal(data []byte, v interface{}) ([]byte, error) {
val := reflect.ValueOf(v)
if val.Kind() != reflect.Ptr {
- return errors.New("non-pointer passed to Unmarshal")
+ return nil, errors.New("non-pointer passed to Unmarshal")
}
- d := decodeState{data, 0}
- return d.unmarshalField(val.Elem())
+ d := decodeState{data: data}
+ err := d.unmarshalField(val.Elem())
+ sum := sha1.Sum(d.data[d.infoOff:d.infoEnd])
+ return sum[:], err
}
func (d *decodeState) unmarshalField(v reflect.Value) error {
@@ -150,7 +155,13 @@ func (d *decodeState) unmarshalDict(v reflect.Value) {
tag := f.Tag.Get("bencode")
name, _ := parseTag(tag)
if name == key || f.Name == key {
+ if key == "info" {
+ d.infoOff = d.off
+ }
d.unmarshalField(v.Field(i))
+ if key == "info" {
+ d.infoEnd = d.off
+ }
break
}
}
diff --git a/bencode/bencode_test.go b/bencode/bencode_test.go
index 8f3c43f..41c1bff 100644
--- a/bencode/bencode_test.go
+++ b/bencode/bencode_test.go
@@ -1,8 +1,8 @@
package bencode
import (
+ "encoding/hex"
"io/ioutil"
- "path/filepath"
"testing"
"time"
@@ -45,21 +45,38 @@ func TestParseInt(t *testing.T) {
}
}
+var testCase = []struct {
+ Torrent string
+ InfoHash string
+}{
+ {
+ Torrent: "../examples/OpenBSD_5.9_amd64_install59.iso-2016-03-29-0449.torrent",
+ InfoHash: "e840038dea1998c39614dcd28594501df02bd32d",
+ },
+ {
+ Torrent: "../examples/OpenBSD_songs_ogg-2016-03-25-0127.torrent",
+ InfoHash: "1aa5af9f6f533961a65169bdeeb801e611724d32",
+ },
+ {
+ Torrent: "../examples/debian-8.5.0-amd64-netinst.iso.torrent",
+ InfoHash: "47b9ad52c009f3bd562ffc6da40e5c55d3fb47f3",
+ },
+}
+
func TestUnmarshal(t *testing.T) {
- files, err := filepath.Glob("../examples/*.torrent")
- if err != nil {
- t.Error(err)
- }
- for _, file := range files {
+ for _, tc := range testCase {
var tor meta.Torrent
- body, err := ioutil.ReadFile(file)
+ body, err := ioutil.ReadFile(tc.Torrent)
if err != nil {
t.Error(err)
}
- err = Unmarshal(body, &tor)
+ sum, err := Unmarshal(body, &tor)
if err != nil {
t.Error(err)
}
- t.Logf("%+v\n", tor)
+ h := hex.EncodeToString(sum)
+ if h != tc.InfoHash {
+ t.Error("got", h, "expected", tc.InfoHash)
+ }
}
}