aboutsummaryrefslogtreecommitdiff
path: root/bencode/bdecode.go
diff options
context:
space:
mode:
authorDimitri Sokolyuk <demon@dim13.org>2016-07-11 02:11:31 +0200
committerDimitri Sokolyuk <demon@dim13.org>2016-07-11 02:11:31 +0200
commita4719c46bda4b3f57ba176f669fdf1ff37e555cc (patch)
tree61bc6f9312873dfe623e7e2481a7fcfe8df3abdf /bencode/bdecode.go
parentf5293d65083b5b9e68b0690f59555c9b34512a69 (diff)
split
Diffstat (limited to 'bencode/bdecode.go')
-rw-r--r--bencode/bdecode.go167
1 files changed, 167 insertions, 0 deletions
diff --git a/bencode/bdecode.go b/bencode/bdecode.go
new file mode 100644
index 0000000..4c7302d
--- /dev/null
+++ b/bencode/bdecode.go
@@ -0,0 +1,167 @@
+package bencode
+
+import (
+ "bytes"
+ "crypto/sha1"
+ "errors"
+ "reflect"
+ "strconv"
+ "time"
+)
+
+var ErrValue = errors.New("Unexpected value")
+
+type decodeState struct {
+ data []byte
+ off int
+}
+
+func Unmarshal(data []byte, v interface{}) error {
+ val := reflect.ValueOf(v)
+ if val.Kind() != reflect.Ptr {
+ return errors.New("non-pointer passed to Unmarshal")
+ }
+ d := decodeState{data: data}
+ err := d.unmarshalField(val.Elem())
+ return err
+}
+
+func (d *decodeState) unmarshalField(v reflect.Value) error {
+ switch d.data[d.off] {
+ case 'd':
+ return d.unmarshalDict(v)
+ case 'i':
+ return d.unmarshalInt(v)
+ case 'l':
+ return d.unmarshalList(v)
+ case 'e':
+ return ErrValue
+ default:
+ return d.unmarshalString(v)
+ }
+ return nil
+}
+
+func (d *decodeState) unmarshalDict(v reflect.Value) error {
+ if d.data[d.off] != 'd' {
+ return ErrValue
+ }
+ d.off++
+
+ var infoOff, infoEnd int
+ for d.data[d.off] != 'e' {
+ key, n := parseString(d.data[d.off:])
+ d.off += n
+ if key == "info" {
+ infoOff = d.off
+ }
+ d.unmarshalField(findKey(key, v))
+ if key == "info" {
+ infoEnd = d.off
+ }
+ }
+
+ if v.CanSet() {
+ if ih := v.FieldByName("InfoHash"); ih.IsValid() && infoEnd > infoOff {
+ sum := sha1.Sum(d.data[infoOff:infoEnd])
+ ih.SetBytes(sum[:])
+ }
+ }
+ d.off++
+ return nil
+}
+
+func (d *decodeState) unmarshalList(v reflect.Value) error {
+ if d.data[d.off] != 'l' {
+ return ErrValue
+ }
+ d.off++
+
+ for i := 0; d.data[d.off] != 'e'; i++ {
+ if v.CanSet() {
+ if i >= v.Cap() {
+ newcap := v.Cap() + v.Cap()/2
+ if newcap < 4 {
+ newcap = 4
+ }
+ newv := reflect.MakeSlice(v.Type(), v.Len(), newcap)
+ reflect.Copy(newv, v)
+ v.Set(newv)
+ }
+ if i >= v.Len() {
+ v.SetLen(i + 1)
+ }
+ d.unmarshalField(v.Index(i))
+ } else {
+ d.unmarshalField(reflect.Value{})
+ }
+ }
+ d.off++
+ return nil
+}
+
+func (d *decodeState) unmarshalString(v reflect.Value) error {
+ s, n := parseString(d.data[d.off:])
+ d.off += n
+ if v.CanSet() {
+ switch v.Kind() {
+ case reflect.Slice:
+ v.SetBytes([]byte(s))
+ default:
+ v.SetString(s)
+ }
+ }
+ return nil
+}
+
+func (d *decodeState) unmarshalInt(v reflect.Value) error {
+ i, n := parseInt(d.data[d.off:])
+ d.off += n
+ if v.CanSet() {
+ switch v.Interface().(type) {
+ case *time.Time:
+ t := time.Unix(i, 0)
+ nv := reflect.New(v.Type())
+ nv.Elem().Set(reflect.ValueOf(&t))
+ v.Set(nv.Elem())
+ case time.Time:
+ t := time.Unix(i, 0)
+ v.Set(reflect.ValueOf(t))
+ case bool:
+ v.SetBool(i == 1)
+ default:
+ v.SetInt(i)
+ }
+ }
+ return nil
+}
+
+func parseString(data []byte) (string, int) {
+ switch data[0] {
+ case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9':
+ default:
+ panic("not a string")
+ }
+ i := bytes.IndexByte(data, ':')
+ if i < 0 || i > 20 { // len(18446744073709551615) == 20 (MaxUint64)
+ panic("separator missing")
+ }
+ size, err := strconv.Atoi(string(data[:i]))
+ if err != nil {
+ panic(err)
+ }
+ end := size + i + 1
+ return string(data[i+1 : end]), end
+}
+
+func parseInt(data []byte) (int64, int) {
+ if data[0] != 'i' {
+ panic("not an int")
+ }
+ end := bytes.IndexByte(data, 'e')
+ i, err := strconv.Atoi(string(data[1:end]))
+ if err != nil {
+ panic(err)
+ }
+ return int64(i), end + 1
+}