aboutsummaryrefslogtreecommitdiff
path: root/bencode/bencode.go
blob: 21b23ee5f994c2eb30d6de48cfbfd64cb2330d1c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
package bencode

import (
	"bytes"
	"errors"
	"fmt"
	"io"
	"reflect"
	"sort"
	"time"
)

// mapping / limitations
// dict -> struct
// list -> array of same type

// Marshal returns the bencoded form of v.
//
// The value is encoded into an in-memory buffer by marshalField. When
// encoding fails the error is returned together with a nil slice, so a caller
// can never mistake a partially written encoding for a complete one.
func Marshal(v interface{}) ([]byte, error) {
	var out bytes.Buffer
	if err := marshalField(&out, reflect.ValueOf(v)); err != nil {
		// Whatever was written before the failure is incomplete; drop it.
		return nil, err
	}
	return out.Bytes(), nil
}

// marshalState is the writer marshalField hands to the other marshal*
// helpers. It forwards data to the real destination and remembers the first
// failure, whether that is a write error or an encoding error recorded by a
// nested marshalField call.
type marshalState struct {
	w   io.Writer
	err error
}

// Write forwards p to the wrapped writer until a failure has been recorded;
// from then on it writes nothing and reports that first failure.
func (s *marshalState) Write(p []byte) (int, error) {
	if s.err != nil {
		return 0, s.err
	}
	n, err := s.w.Write(p)
	if err != nil {
		s.err = err
	}
	return n, err
}

// fail records err unless an earlier failure is already stored, and returns
// the stored failure.
func (s *marshalState) fail(err error) error {
	if s.err == nil {
		s.err = err
	}
	return s.err
}

// marshalField writes the bencoded form of v to out.
//
// Strings become byte strings; signed and unsigned integers become integers;
// booleans become the integers 1 and 0; slices are handed to marshalList and
// structs to marshalDict; non-nil pointers and interface values are encoded
// as the value they refer to. An invalid value, a nil pointer or interface,
// or any other kind yields an error instead of silently producing no output.
//
// The helpers called from here receive a *marshalState as their writer. Every
// nested marshalField call made through them records its failure there, so
// the outermost call returns the first error that occurred anywhere below it
// even if a helper in between does not pass the error along.
func marshalField(out io.Writer, v reflect.Value) error {
	st, nested := out.(*marshalState)
	if !nested {
		st = &marshalState{w: out}
	}
	if st.err != nil {
		return st.err
	}
	if !v.IsValid() {
		return st.fail(errors.New("invalid value"))
	}

	switch v.Kind() {
	case reflect.String:
		marshalString(st, v.String())
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		marshalInt(st, v.Int())
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		// Written directly: marshalInt takes an int64, which cannot hold
		// every uint64 value.
		fmt.Fprintf(st, "i%de", v.Uint())
	case reflect.Slice:
		marshalList(st, v)
	case reflect.Struct:
		marshalDict(st, v)
	case reflect.Bool:
		var n int64
		if v.Bool() {
			n = 1
		}
		marshalInt(st, n)
	case reflect.Ptr, reflect.Interface:
		if v.IsNil() {
			return st.fail(fmt.Errorf("nil %s value", v.Type()))
		}
		return marshalField(st, v.Elem())
	default:
		return st.fail(fmt.Errorf("unsupported type %s", v.Type()))
	}
	// Picks up write errors and failures recorded by nested calls.
	return st.err
}

// isZero reports whether v holds the "empty" value for its kind:
//
//   - signed and unsigned integers: 0
//   - strings and slices: length 0
//   - pointers and interface values: nil
//   - time.Time, stored directly or inside an interface value: Time.IsZero
//
// Every other value, including the boolean false and structs other than
// time.Time, is reported as non-zero.
func isZero(v reflect.Value) bool {
	switch v.Kind() {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return v.Int() == 0
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return v.Uint() == 0
	case reflect.String, reflect.Slice:
		return v.Len() == 0
	case reflect.Ptr:
		return v.IsNil()
	case reflect.Interface, reflect.Struct:
		if v.Kind() == reflect.Interface && v.IsNil() {
			return true
		}
		// Interface panics for values reached through unexported struct
		// fields, so those are reported as non-zero without inspecting them.
		if !v.CanInterface() {
			return false
		}
		// A time.Time value has kind Struct; it only shows up under kind
		// Interface when the field itself is declared with an interface type.
		t, ok := v.Interface().(time.Time)
		return ok && t.IsZero()
	}
	return false
}

type byName []reflect.StructField

func (n byName) Len() int           { return len(n) }
func (n byName) Less(i, j int) bool { return n[i].Name < n[j].Name }
func (n byName) Swap(i, j int)      { n[i], n[j] = n[j], n[i] }

// marshalDict writes the struct v to out.
//
// A time.Time is written as a single integer holding its Unix timestamp.
// Any other struct is written as a dictionary: "d", then one key/value pair
// per field in the order produced by byName, then "e". For each field:
//
//   - a `bencode:"-"` tag skips the field;
//   - the key is the name returned by parseTag for the tag, or the Go field
//     name when that is empty;
//   - a field whose tag parameter is "optional" is skipped when isZero
//     reports its value as zero.
//
// Encoding stops at the first failing field or delimiter write and that
// error is returned, so no further pairs are appended after a failure.
func marshalDict(out io.Writer, v reflect.Value) error {
	if t, ok := v.Interface().(time.Time); ok {
		marshalInt(out, t.Unix())
		return nil
	}

	typ := v.Type()
	fields := make([]reflect.StructField, typ.NumField())
	for i := range fields {
		fields[i] = typ.Field(i)
	}
	sort.Sort(byName(fields))

	if _, err := io.WriteString(out, "d"); err != nil {
		return err
	}
	for _, f := range fields {
		tag := f.Tag.Get("bencode")
		if tag == "-" {
			continue
		}
		key, param := parseTag(tag)
		if key == "" {
			key = f.Name
		}
		fv := v.FieldByIndex(f.Index)
		if param == "optional" && isZero(fv) {
			continue
		}
		marshalString(out, key)
		if err := marshalField(out, fv); err != nil {
			// A key without its value would corrupt the dictionary; give up.
			return err
		}
	}
	_, err := io.WriteString(out, "e")
	return err
}

// marshalList writes the slice v to out.
//
// A slice whose element kind is uint8 (a byte slice) is written as a single
// byte string. Every other slice is written as a list: "l", each element
// encoded by marshalField in index order, then "e".
//
// Encoding stops at the first failing element or delimiter write and that
// error is returned.
func marshalList(out io.Writer, v reflect.Value) error {
	if v.Type().Elem().Kind() == reflect.Uint8 {
		marshalString(out, string(v.Bytes()))
		return nil
	}

	if _, err := io.WriteString(out, "l"); err != nil {
		return err
	}
	for i, n := 0, v.Len(); i < n; i++ {
		if err := marshalField(out, v.Index(i)); err != nil {
			return err
		}
	}
	_, err := io.WriteString(out, "e")
	return err
}

// marshalString writes s to out as a bencoded byte string: the length of s
// in bytes as a decimal number, a colon, then the bytes of s. It returns the
// error reported by the write, if any.
func marshalString(out io.Writer, s string) error {
	_, err := fmt.Fprintf(out, "%d:%s", len(s), s)
	return err
}

// marshalInt writes i to out as a bencoded integer: "i", the decimal digits
// of i (with a leading minus sign when negative), then "e". It returns the
// error reported by the write, if any.
func marshalInt(out io.Writer, i int64) error {
	_, err := fmt.Fprintf(out, "i%de", i)
	return err
}