aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--file.go56
-rw-r--r--file_test.go23
-rw-r--r--gen.go4
3 files changed, 57 insertions, 26 deletions
diff --git a/file.go b/file.go
index e8218d8..0497c09 100644
--- a/file.go
+++ b/file.go
@@ -36,45 +36,58 @@ func checkComment(comment string) error {
return nil
}
-func Parse(r io.Reader) (File, error) {
+func (f *File) ReadFrom(r io.Reader) error {
buf := bufio.NewReader(r)
comment, err := buf.ReadString('\n')
if err != nil {
- return File{}, err
+ return err
}
if err := checkComment(comment); err != nil {
- return File{}, err
+ return err
}
- comment = comment[len(commentHdr):]
+ f.Comment = strings.TrimSpace(comment[len(commentHdr):])
b64, err := buf.ReadBytes('\n')
if err != nil {
- return File{}, err
+ return err
}
+ f.B64 = bytes.TrimSpace(b64)
- body, err := ioutil.ReadAll(buf)
+ f.Body, err = ioutil.ReadAll(buf)
if err != nil {
- return File{}, err
+ return err
}
- return File{
- Comment: strings.TrimSpace(comment),
- B64: bytes.TrimSpace(b64),
- Body: body,
- }, nil
+ return nil
+}
+
+func Parse(b []byte) (*File, error) {
+ r := bytes.NewReader(b)
+ f := new(File)
+ return f, f.ReadFrom(r)
}
-func ParseFile(fname string) (File, error) {
+func ParseFile(fname string) (*File, error) {
fd, err := os.Open(fname)
if err != nil {
- return File{}, err
+ return nil, err
}
defer fd.Close()
- return Parse(fd)
+ f := new(File)
+ return f, f.ReadFrom(fd)
}
-func (f File) Encode(w io.Writer) error {
+func (f File) Bytes() ([]byte, error) {
+ buf := new(bytes.Buffer)
+ err := f.WriteTo(buf)
+ if err != nil {
+ return nil, err
+ }
+ return buf.Bytes(), nil
+}
+
+func (f File) WriteTo(w io.Writer) error {
fmt.Fprintf(w, "%v%v\n", commentHdr, f.Comment)
fmt.Fprintf(w, "%v\n", string(f.B64))
if f.Body != nil {
@@ -83,11 +96,16 @@ func (f File) Encode(w io.Writer) error {
return nil
}
-func (f File) EncodeFile(fname string) error {
- fd, err := os.Create(fname)
+const (
+ SecMode os.FileMode = 0600
+ PubMode os.FileMode = 0644
+)
+
+func (f File) WriteFile(fname string, perm os.FileMode) error {
+ fd, err := os.OpenFile(fname, os.O_WRONLY|os.O_CREATE, perm)
if err != nil {
return err
}
defer fd.Close()
- return f.Encode(fd)
+ return f.WriteTo(fd)
}
diff --git a/file_test.go b/file_test.go
index 7ede9fa..acacb2e 100644
--- a/file_test.go
+++ b/file_test.go
@@ -2,6 +2,7 @@ package main
import (
"bytes"
+ "io/ioutil"
"path"
"testing"
)
@@ -16,14 +17,26 @@ func TestParseFile(t *testing.T) {
}
for _, tc := range testCases {
t.Run(tc, func(t *testing.T) {
- f, err := ParseFile(path.Join("testdata", tc))
+ fileName := path.Join("testdata", tc)
+
+ body, err := ioutil.ReadFile(fileName)
+ if err != nil {
+ t.Error(err)
+ }
+
+ f, err := Parse(body)
if err != nil {
t.Error(err)
}
- t.Logf("%+v", f)
- buf := new(bytes.Buffer)
- f.Encode(buf)
- t.Logf("%v", buf.String())
+
+ res, err := f.Bytes()
+ if err != nil {
+ t.Error(err)
+ }
+
+ if !bytes.Equal(res, body) {
+ t.Errorf("got %v, want %v", res, body)
+ }
})
}
}
diff --git a/gen.go b/gen.go
index 4c52cc9..4b55b74 100644
--- a/gen.go
+++ b/gen.go
@@ -33,7 +33,7 @@ func Generate(pubkeyfile, seckeyfile, comment string, rounds int) error {
Comment: comment + " secret key",
B64: sb64,
}
- if err := sfile.EncodeFile(seckeyfile); err != nil {
+ if err := sfile.WriteFile(seckeyfile, SecMode); err != nil {
return err
}
@@ -50,7 +50,7 @@ func Generate(pubkeyfile, seckeyfile, comment string, rounds int) error {
Comment: comment + " public key",
B64: pb64,
}
- if err := pfile.EncodeFile(pubkeyfile); err != nil {
+ if err := pfile.WriteFile(pubkeyfile, PubMode); err != nil {
return err
}