package bencode import ( "bytes" "crypto/sha1" "errors" "fmt" "io" "reflect" "sort" "strconv" "strings" "time" ) // mapping / limitations // dict -> struct // list -> aray of same type var ErrValue = errors.New("Unexpected value") func Marshal(v interface{}) ([]byte, error) { var out bytes.Buffer val := reflect.ValueOf(v) err := marshalField(&out, val) return out.Bytes(), err } func marshalField(out io.Writer, v reflect.Value) error { if !v.IsValid() { return errors.New("ivalid value") } switch v.Kind() { case reflect.String: marshalString(out, v.String()) case reflect.Int: marshalInt(out, v.Int()) case reflect.Slice: marshalList(out, v) case reflect.Struct: marshalDict(out, v) case reflect.Bool: if v.Bool() { marshalInt(out, 1) } else { marshalInt(out, 0) } } return nil } func isEmpty(v reflect.Value) bool { switch v.Kind() { case reflect.Int: return v.Int() == 0 case reflect.String, reflect.Slice: return v.Len() == 0 case reflect.Interface, reflect.Ptr: return v.IsNil() } return false } type byName []reflect.StructField func (n byName) Len() int { return len(n) } func (n byName) Less(i, j int) bool { return n[i].Name < n[j].Name } func (n byName) Swap(i, j int) { n[i], n[j] = n[j], n[i] } func marshalDict(out io.Writer, v reflect.Value) { switch val := v.Interface().(type) { case time.Time: marshalInt(out, val.Unix()) default: t := v.Type() fields := make([]reflect.StructField, t.NumField()) for i := 0; i < t.NumField(); i++ { fields[i] = t.Field(i) } sort.Sort(byName(fields)) io.WriteString(out, "d") for _, n := range fields { tag := n.Tag.Get("bencode") if tag == "-" { continue } name, param := parseTag(tag) if name == "" { name = n.Name } vf := v.FieldByIndex(n.Index) if param == "optional" && isEmpty(vf) { continue } marshalString(out, name) marshalField(out, vf) } io.WriteString(out, "e") } } func marshalList(out io.Writer, v reflect.Value) { switch v.Type().Elem().Kind() { case reflect.Uint8: marshalString(out, string(v.Bytes())) default: io.WriteString(out, "l") for i := 0; i < v.Len(); i++ { marshalField(out, v.Index(i)) } io.WriteString(out, "e") } } func marshalString(out io.Writer, s string) { fmt.Fprintf(out, "%d:%s", len(s), s) } func marshalInt(out io.Writer, i int64) { fmt.Fprintf(out, "i%de", i) } func parseTag(tag string) (string, string) { if i := strings.Index(tag, ","); i != -1 { return tag[:i], tag[i+1:] } return tag, "" } type decodeState struct { data []byte off int } func Unmarshal(data []byte, v interface{}) error { val := reflect.ValueOf(v) if val.Kind() != reflect.Ptr { return errors.New("non-pointer passed to Unmarshal") } d := decodeState{data: data} err := d.unmarshalField(val.Elem()) return err } func (d *decodeState) unmarshalField(v reflect.Value) error { switch d.data[d.off] { case 'd': return d.unmarshalDict(v) case 'i': return d.unmarshalInt(v) case 'l': return d.unmarshalList(v) case 'e': return ErrValue default: return d.unmarshalString(v) } return nil } func findKey(key string, v reflect.Value) reflect.Value { if v.CanSet() { t := v.Type() for i := 0; i < t.NumField(); i++ { f := t.Field(i) tag := f.Tag.Get("bencode") name, _ := parseTag(tag) if key == name || key == f.Name || key == strings.ToLower(f.Name) { return v.Field(i) } } } return reflect.Value{} } func (d *decodeState) unmarshalDict(v reflect.Value) error { if d.data[d.off] != 'd' { return ErrValue } d.off++ var infoOff, infoEnd int for d.data[d.off] != 'e' { key, n := parseString(d.data[d.off:]) d.off += n f := findKey(key, v) if key == "info" { infoOff = d.off } d.unmarshalField(f) if key == "info" { infoEnd = d.off } } if v.CanSet() { if ih := v.FieldByName("InfoHash"); ih.IsValid() && infoEnd > infoOff { sum := sha1.Sum(d.data[infoOff:infoEnd]) ih.SetBytes(sum[:]) } } d.off++ return nil } func (d *decodeState) unmarshalList(v reflect.Value) error { if d.data[d.off] != 'l' { return ErrValue } d.off++ for i := 0; d.data[d.off] != 'e'; i++ { if v.CanSet() { if i >= v.Cap() { newcap := v.Cap() + v.Cap()/2 if newcap < 4 { newcap = 4 } newv := reflect.MakeSlice(v.Type(), v.Len(), newcap) reflect.Copy(newv, v) v.Set(newv) } if i >= v.Len() { v.SetLen(i + 1) } d.unmarshalField(v.Index(i)) } else { d.unmarshalField(reflect.Value{}) } } d.off++ return nil } func (d *decodeState) unmarshalString(v reflect.Value) error { s, n := parseString(d.data[d.off:]) d.off += n if v.CanSet() { switch v.Kind() { case reflect.Slice: v.SetBytes([]byte(s)) default: v.SetString(s) } } return nil } func parseString(data []byte) (string, int) { switch data[0] { case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9': default: panic("not a string") } i := bytes.IndexByte(data, ':') if i < 0 || i > 20 { // len(18446744073709551615) == 20 (MaxUint64) panic("separator missing") } size, err := strconv.Atoi(string(data[:i])) if err != nil { panic(err) } end := size + i + 1 return string(data[i+1 : end]), end } func (d *decodeState) unmarshalInt(v reflect.Value) error { i, n := parseInt(d.data[d.off:]) d.off += n if v.CanSet() { switch v.Interface().(type) { case time.Time: t := time.Unix(i, 0) v.Set(reflect.ValueOf(t)) case bool: v.SetBool(i == 1) default: v.SetInt(i) } } return nil } func parseInt(data []byte) (int64, int) { if data[0] != 'i' { panic("not an int") } end := bytes.IndexByte(data, 'e') i, err := strconv.Atoi(string(data[1:end])) if err != nil { panic(err) } return int64(i), end + 1 }