package bencode

import (
	"bytes"
	"errors"
	"reflect"
	"strconv"
	"time"
)

// Leading bytes that identify each bencode value type, plus the shared
// terminator. Any other leading byte is treated as the start of a
// length-prefixed string.
const (
	tokenDict = 'd'
	tokenList = 'l'
	tokenInt  = 'i'
	tokenEnd  = 'e'
)

// Unmarshaler is implemented by types that decode themselves. When a
// dictionary value maps to an addressable field whose pointer type implements
// Unmarshaler, UnmarshalBencode is handed the remaining input starting at that
// value and must return the number of bytes it consumed.
type Unmarshaler interface {
	UnmarshalBencode([]byte) (int, error)
}

// ErrValue is returned when the input is malformed or truncated, or when a
// decoded value cannot be stored in the destination.
var ErrValue = errors.New("Unexpected value")

// decodeState holds the input buffer and the offset of the next unread byte.
type decodeState struct {
	data []byte
	off  int
}

// Unmarshal decodes the bencoded value at the start of data into the value
// pointed to by v. It returns the number of bytes consumed so far together
// with any error; on error the count is the offset at which decoding stopped.
//
// v must be a non-nil pointer. Malformed or truncated input yields an error
// instead of a panic.
func Unmarshal(data []byte, v interface{}) (int, error) {
	val := reflect.ValueOf(v)
	if val.Kind() != reflect.Ptr {
		return 0, errors.New("non-pointer passed to Unmarshal")
	}
	if val.IsNil() {
		return 0, errors.New("nil pointer passed to Unmarshal")
	}

	d := decodeState{data: data}
	err := d.unmarshalField(val.Elem())
	return d.off, err
}

// unmarshalField decodes the value at the current offset into v, dispatching
// on its leading byte. v may be the zero reflect.Value, in which case the
// value is parsed and skipped.
func (d *decodeState) unmarshalField(v reflect.Value) error {
	if d.off >= len(d.data) {
		return ErrValue
	}

	switch d.data[d.off] {
	case tokenDict:
		return d.unmarshalDict(v)
	case tokenInt:
		return d.unmarshalInt(v)
	case tokenList:
		return d.unmarshalList(v)
	case tokenEnd:
		// A terminator where a value is expected.
		return ErrValue
	default:
		return d.unmarshalString(v)
	}
}

// unmarshalDict decodes a dictionary into v. Each key is resolved to a
// destination with findKey; the destination's own Unmarshaler implementation
// is preferred when it has one. If v is a struct with a settable []byte field
// named "Raw", that field is set to the dictionary's raw encoding (a sub-slice
// of the input buffer, not a copy).
func (d *decodeState) unmarshalDict(v reflect.Value) error {
	if d.off >= len(d.data) || d.data[d.off] != tokenDict {
		return ErrValue
	}
	rawStart := d.off
	d.off++

	for {
		if d.off >= len(d.data) {
			// Input ended before the closing terminator.
			return ErrValue
		}
		if d.data[d.off] == tokenEnd {
			break
		}

		key, n, err := decodeStringToken(d.data[d.off:])
		if err != nil {
			return err
		}
		d.off += n

		val := findKey(key, v)

		// CanAddr is checked first: it is false for the zero Value, on which
		// CanInterface would panic.
		if val.CanAddr() && val.CanInterface() && val.Addr().Type().NumMethod() > 0 {
			if u, ok := val.Addr().Interface().(Unmarshaler); ok {
				used, err := u.UnmarshalBencode(d.data[d.off:])
				if err != nil {
					return err
				}
				// Do not trust a custom decoder to report a sane byte count.
				if used < 0 || used > len(d.data)-d.off {
					return ErrValue
				}
				d.off += used
				continue
			}
		}

		if err := d.unmarshalField(val); err != nil {
			return err
		}
	}
	d.off++ // consume the terminator

	if v.Kind() == reflect.Struct {
		raw := v.FieldByName("Raw")
		if raw.IsValid() && raw.CanSet() &&
			raw.Kind() == reflect.Slice && raw.Type().Elem().Kind() == reflect.Uint8 {
			raw.SetBytes(d.data[rawStart:d.off])
		}
	}
	return nil
}

// unmarshalList decodes a list into v. A settable slice is grown as needed and
// filled element by element; elements already present beyond the decoded count
// are left untouched. A non-settable v causes the list to be parsed and
// skipped. A settable v that is not a slice yields ErrValue.
func (d *decodeState) unmarshalList(v reflect.Value) error {
	if d.off >= len(d.data) || d.data[d.off] != tokenList {
		return ErrValue
	}
	d.off++

	store := v.CanSet()
	if store && v.Kind() != reflect.Slice {
		return ErrValue
	}

	for i := 0; ; i++ {
		if d.off >= len(d.data) {
			// Input ended before the closing terminator.
			return ErrValue
		}
		if d.data[d.off] == tokenEnd {
			break
		}

		var elem reflect.Value // zero Value: parse and discard
		if store {
			if i >= v.Cap() {
				// Grow by half the current capacity, with a floor of 4.
				newcap := v.Cap() + v.Cap()/2
				if newcap < 4 {
					newcap = 4
				}
				newv := reflect.MakeSlice(v.Type(), v.Len(), newcap)
				reflect.Copy(newv, v)
				v.Set(newv)
			}
			if i >= v.Len() {
				v.SetLen(i + 1)
			}
			elem = v.Index(i)
		}

		if err := d.unmarshalField(elem); err != nil {
			return err
		}
	}
	d.off++ // consume the terminator
	return nil
}

// unmarshalString decodes a length-prefixed string into v. Supported
// destinations are strings, byte slices, and byte arrays whose length equals
// the string's length. A non-settable v causes the string to be skipped; any
// other settable kind yields ErrValue.
func (d *decodeState) unmarshalString(v reflect.Value) error {
	s, n, err := decodeStringToken(d.data[d.off:])
	if err != nil {
		return err
	}
	d.off += n

	if !v.CanSet() {
		return nil
	}

	switch v.Kind() {
	case reflect.Slice:
		if v.Type().Elem().Kind() != reflect.Uint8 {
			return ErrValue
		}
		v.SetBytes([]byte(s))
	case reflect.Array:
		if v.Type().Elem().Kind() != reflect.Uint8 || len(s) != v.Len() {
			return ErrValue
		}
		for i := 0; i < len(s); i++ {
			// SetUint also works for named byte element types.
			v.Index(i).SetUint(uint64(s[i]))
		}
	case reflect.String:
		v.SetString(s)
	default:
		return ErrValue
	}
	return nil
}

// unmarshalInt decodes an integer into v. time.Time and *time.Time receive the
// value interpreted as Unix seconds, time.Duration receives it as a number of
// seconds, and bool is true only when the value equals 1. Other signed and
// unsigned integer kinds are stored directly; values that do not fit, and
// unsupported destination kinds, yield ErrValue. A non-settable v causes the
// integer to be skipped.
func (d *decodeState) unmarshalInt(v reflect.Value) error {
	i, n, err := decodeIntToken(d.data[d.off:])
	if err != nil {
		return err
	}
	d.off += n

	if !v.CanSet() {
		return nil
	}

	switch v.Interface().(type) {
	case *time.Time:
		t := time.Unix(i, 0)
		v.Set(reflect.ValueOf(&t))
	case time.Time:
		v.Set(reflect.ValueOf(time.Unix(i, 0)))
	case time.Duration:
		v.SetInt(i * int64(time.Second))
	case bool:
		v.SetBool(i == 1)
	default:
		switch v.Kind() {
		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
			if v.OverflowInt(i) {
				return ErrValue
			}
			v.SetInt(i)
		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
			if i < 0 || v.OverflowUint(uint64(i)) {
				return ErrValue
			}
			v.SetUint(uint64(i))
		case reflect.Bool:
			// Named bool types are not matched by the type switch above.
			v.SetBool(i == 1)
		default:
			return ErrValue
		}
	}
	return nil
}

// decodeStringToken scans a "<length>:<bytes>" string at the start of data and
// returns its contents and the total number of bytes it occupies.
func decodeStringToken(data []byte) (string, int, error) {
	if len(data) == 0 || data[0] < '0' || data[0] > '9' {
		return "", 0, errors.New("not a string")
	}

	i := bytes.IndexByte(data, ':')
	if i < 0 || i > 20 { // len(18446744073709551615) == 20 (MaxUint64)
		return "", 0, errors.New("separator missing")
	}

	size, err := strconv.Atoi(string(data[:i]))
	if err != nil {
		return "", 0, err
	}
	// Compare against the remaining length rather than computing the end
	// first, so a huge declared size cannot overflow the addition.
	if size < 0 || size > len(data)-i-1 {
		return "", 0, ErrValue
	}

	end := i + 1 + size
	return string(data[i+1 : end]), end, nil
}

// decodeIntToken scans an "i<digits>e" integer at the start of data and
// returns its value and the total number of bytes it occupies.
func decodeIntToken(data []byte) (int64, int, error) {
	if len(data) == 0 || data[0] != tokenInt {
		return 0, 0, errors.New("not an int")
	}

	end := bytes.IndexByte(data, tokenEnd)
	if end < 0 {
		return 0, 0, errors.New("terminator missing")
	}

	// ParseInt with bitSize 64 keeps the full int64 range on every platform.
	i, err := strconv.ParseInt(string(data[1:end]), 10, 64)
	if err != nil {
		return 0, 0, err
	}
	return i, end + 1, nil
}

// parseString returns the string encoded at the start of data and the number
// of bytes it occupies. It panics if data does not start with a well-formed
// string that fits within data.
func parseString(data []byte) (string, int) {
	s, n, err := decodeStringToken(data)
	if err != nil {
		panic(err)
	}
	return s, n
}

// parseInt returns the integer encoded at the start of data and the number of
// bytes it occupies. It panics if data does not start with a well-formed
// integer.
func parseInt(data []byte) (int64, int) {
	i, n, err := decodeIntToken(data)
	if err != nil {
		panic(err)
	}
	return i, n
}