~nskaggs/+junk/xenial-test

« back to all changes in this revision

Viewing changes to src/github.com/juju/httprequest/cmd/httprequest-generate-client/main.go

  • Committer: Nicholas Skaggs
  • Date: 2016-10-24 20:56:05 UTC
  • Revision ID: nicholas.skaggs@canonical.com-20161024205605-z8lta0uvuhtxwzwl
Initi with beta15

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
// +build go1.6
 
2
 
 
3
package main
 
4
 
 
5
import (
 
6
        "bytes"
 
7
        "flag"
 
8
        "fmt"
 
9
        "go/ast"
 
10
        "go/build"
 
11
        "go/format"
 
12
        "go/parser"
 
13
        "io/ioutil"
 
14
        "os"
 
15
        "strings"
 
16
        "text/template"
 
17
 
 
18
        "golang.org/x/tools/go/loader"
 
19
        "go/types"
 
20
        "gopkg.in/errgo.v1"
 
21
)
 
22
 
 
23
// TODO:
 
24
// - generate exported types if the parameter/response types aren't exported?
 
25
// - deal with literal interface and struct types.
 
26
// - copy doc comments from server methods.
 
27
 
 
28
func main() {
 
29
        flag.Usage = func() {
 
30
                fmt.Fprintf(os.Stderr, "usage: httprequest-generate server-package server-type client-type\n")
 
31
                os.Exit(2)
 
32
        }
 
33
        flag.Parse()
 
34
        if flag.NArg() != 3 {
 
35
                flag.Usage()
 
36
        }
 
37
 
 
38
        serverPkg, serverType, clientType := flag.Arg(0), flag.Arg(1), flag.Arg(2)
 
39
 
 
40
        if err := generate(serverPkg, serverType, clientType); err != nil {
 
41
                fmt.Fprintf(os.Stderr, "%v\n", err)
 
42
                os.Exit(1)
 
43
        }
 
44
}
 
45
 
 
46
type templateArg struct {
 
47
        PkgName    string
 
48
        Imports    []string
 
49
        Methods    []method
 
50
        ClientType string
 
51
}
 
52
 
 
53
var code = template.Must(template.New("").Parse(`
 
54
// The code in this file was automatically generated by running httprequest-generate-client.
 
55
// DO NOT EDIT
 
56
 
 
57
package {{.PkgName}}
 
58
import (
 
59
        {{range .Imports}}{{printf "%q" .}}
 
60
        {{end}}
 
61
)
 
62
 
 
63
type {{.ClientType}} struct {
 
64
        Client httprequest.Client
 
65
}
 
66
 
 
67
{{range .Methods}}
 
68
{{if .RespType}}
 
69
        {{.Doc}}
 
70
        func (c *{{$.ClientType}}) {{.Name}}(p *{{.ParamType}}) ({{.RespType}}, error) {
 
71
                var r {{.RespType}}
 
72
                err := c.Client.Call(p, &r)
 
73
                return r, err
 
74
        }
 
75
{{else}}
 
76
        {{.Doc}}
 
77
        func (c *{{$.ClientType}}) {{.Name}}(p *{{.ParamType}}) (error) {
 
78
                return c.Client.Call(p, nil)
 
79
        }
 
80
{{end}}
 
81
{{end}}
 
82
`))
 
83
 
 
84
func generate(serverPkgPath, serverType, clientType string) error {
 
85
        currentDir, err := os.Getwd()
 
86
        if err != nil {
 
87
                return err
 
88
        }
 
89
        localPkg, err := build.Import(".", currentDir, 0)
 
90
        if err != nil {
 
91
                return errgo.Notef(err, "cannot open package in current directory")
 
92
        }
 
93
        serverPkg, err := build.Import(serverPkgPath, currentDir, 0)
 
94
        if err != nil {
 
95
                return errgo.Notef(err, "cannot open %q", serverPkgPath)
 
96
        }
 
97
 
 
98
        methods, imports, err := serverMethods(serverPkg.ImportPath, serverType, localPkg.ImportPath)
 
99
        if err != nil {
 
100
                return errgo.Mask(err)
 
101
        }
 
102
        arg := templateArg{
 
103
                Imports:    imports,
 
104
                Methods:    methods,
 
105
                PkgName:    localPkg.Name,
 
106
                ClientType: clientType,
 
107
        }
 
108
        var buf bytes.Buffer
 
109
        if err := code.Execute(&buf, arg); err != nil {
 
110
                return errgo.Mask(err)
 
111
        }
 
112
        data, err := format.Source(buf.Bytes())
 
113
        if err != nil {
 
114
                return errgo.Notef(err, "cannot format source")
 
115
        }
 
116
        if err := writeOutput(data, clientType); err != nil {
 
117
                return errgo.Mask(err)
 
118
        }
 
119
        return nil
 
120
}
 
121
 
 
122
func writeOutput(data []byte, clientType string) error {
 
123
        filename := strings.ToLower(clientType) + "_generated.go"
 
124
        if err := ioutil.WriteFile(filename, data, 0644); err != nil {
 
125
                return errgo.Mask(err)
 
126
        }
 
127
        return nil
 
128
}
 
129
 
 
130
type method struct {
 
131
        Name      string
 
132
        Doc       string
 
133
        ParamType string
 
134
        RespType  string
 
135
}
 
136
 
 
137
func serverMethods(serverPkg, serverType, localPkg string) ([]method, []string, error) {
 
138
        cfg := loader.Config{
 
139
                TypeCheckFuncBodies: func(string) bool {
 
140
                        return false
 
141
                },
 
142
                ImportPkgs: map[string]bool{
 
143
                        serverPkg: false, // false means don't load tests.
 
144
                },
 
145
                ParserMode: parser.ParseComments,
 
146
        }
 
147
        prog, err := cfg.Load()
 
148
        if err != nil {
 
149
                return nil, nil, errgo.Notef(err, "cannot load %q", serverPkg)
 
150
        }
 
151
        pkgInfo := prog.Imported[serverPkg]
 
152
        if pkgInfo == nil {
 
153
                return nil, nil, errgo.Newf("cannot find %q in imported code", serverPkg)
 
154
        }
 
155
        pkg := pkgInfo.Pkg
 
156
        obj := pkg.Scope().Lookup(serverType)
 
157
        if obj == nil {
 
158
                return nil, nil, errgo.Newf("type %s not found in %s", serverType, serverPkg)
 
159
        }
 
160
        objTypeName, ok := obj.(*types.TypeName)
 
161
        if !ok {
 
162
                return nil, nil, errgo.Newf("%s is not a type", serverType)
 
163
        }
 
164
        // Use the pointer type to get as many methods as possible.
 
165
        ptrObjType := types.NewPointer(objTypeName.Type())
 
166
 
 
167
        imports := map[string]string{
 
168
                "github.com/juju/httprequest": "httprequest",
 
169
                localPkg:                      "",
 
170
        }
 
171
        var methods []method
 
172
        mset := types.NewMethodSet(ptrObjType)
 
173
        for i := 0; i < mset.Len(); i++ {
 
174
                sel := mset.At(i)
 
175
                if !sel.Obj().Exported() {
 
176
                        continue
 
177
                }
 
178
                name := sel.Obj().Name()
 
179
                if name == "Close" {
 
180
                        continue
 
181
                }
 
182
                ptype, rtype, err := parseMethodType(sel.Type().(*types.Signature))
 
183
                if err != nil {
 
184
                        fmt.Fprintf(os.Stderr, "ignoring method %s: %v\n", name, err)
 
185
                        continue
 
186
                }
 
187
                comment := docComment(prog, sel)
 
188
                methods = append(methods, method{
 
189
                        Name:      name,
 
190
                        Doc:       comment,
 
191
                        ParamType: typeStr(ptype, imports),
 
192
                        RespType:  typeStr(rtype, imports),
 
193
                })
 
194
        }
 
195
        delete(imports, localPkg)
 
196
        var allImports []string
 
197
        for path := range imports {
 
198
                allImports = append(allImports, path)
 
199
        }
 
200
        return methods, allImports, nil
 
201
}
 
202
 
 
203
// docComment returns the doc comment for the method referred to
 
204
// by the given selection.
 
205
func docComment(prog *loader.Program, sel *types.Selection) string {
 
206
        obj := sel.Obj()
 
207
        tokFile := prog.Fset.File(obj.Pos())
 
208
        if tokFile == nil {
 
209
                panic("no file found for method")
 
210
        }
 
211
        filename := tokFile.Name()
 
212
        for _, pkgInfo := range prog.AllPackages {
 
213
                for _, f := range pkgInfo.Files {
 
214
                        if tokFile := prog.Fset.File(f.Pos()); tokFile == nil || tokFile.Name() != filename {
 
215
                                continue
 
216
                        }
 
217
                        // We've found the file we're looking for. Now traverse all
 
218
                        // top level declarations looking for the right function declaration.
 
219
                        for _, decl := range f.Decls {
 
220
                                fdecl, ok := decl.(*ast.FuncDecl)
 
221
                                if ok && fdecl.Name.Pos() == obj.Pos() {
 
222
                                        // Found it!
 
223
                                        return commentStr(fdecl.Doc)
 
224
                                }
 
225
                        }
 
226
                }
 
227
        }
 
228
        panic("method declaration not found")
 
229
}
 
230
 
 
231
func commentStr(c *ast.CommentGroup) string {
 
232
        if c == nil {
 
233
                return ""
 
234
        }
 
235
        var b []byte
 
236
        for i, cc := range c.List {
 
237
                if i > 0 {
 
238
                        b = append(b, '\n')
 
239
                }
 
240
                b = append(b, cc.Text...)
 
241
        }
 
242
        return string(b)
 
243
}
 
244
 
 
245
// typeStr returns the type string to be used when using the
 
246
// given type. It adds any needed import paths to the given
 
247
// imports map (map from package path to package id).
 
248
func typeStr(t types.Type, imports map[string]string) string {
 
249
        if t == nil {
 
250
                return ""
 
251
        }
 
252
        qualify := func(pkg *types.Package) string {
 
253
                if name, ok := imports[pkg.Path()]; ok {
 
254
                        return name
 
255
                }
 
256
                name := pkg.Name()
 
257
                // Make sure we're not duplicating the name.
 
258
                // TODO if we are, make a new non-duplicated version.
 
259
                for oldPkg, oldName := range imports {
 
260
                        if oldName == name {
 
261
                                panic(errgo.Newf("duplicate package name %s vs %s", pkg.Path(), oldPkg))
 
262
                        }
 
263
                }
 
264
                imports[pkg.Path()] = name
 
265
                return name
 
266
        }
 
267
        return types.TypeString(t, qualify)
 
268
}
 
269
 
 
270
func parseMethodType(t *types.Signature) (ptype, rtype types.Type, err error) {
 
271
        mp := t.Params()
 
272
        if mp.Len() != 1 && mp.Len() != 2 {
 
273
                return nil, nil, errgo.New("wrong argument count")
 
274
        }
 
275
        ptype0 := mp.At(mp.Len() - 1).Type()
 
276
        ptype1, ok := ptype0.(*types.Pointer)
 
277
        if !ok {
 
278
                return nil, nil, errgo.New("parameter is not a pointer")
 
279
        }
 
280
        ptype = ptype1.Elem()
 
281
        if _, ok := ptype.Underlying().(*types.Struct); !ok {
 
282
                return nil, nil, errgo.Newf("parameter is %s, not a pointer to struct", ptype1.Elem())
 
283
        }
 
284
        rp := t.Results()
 
285
        if rp.Len() > 2 {
 
286
                return nil, nil, errgo.New("wrong result count")
 
287
        }
 
288
        if rp.Len() == 2 {
 
289
                rtype = rp.At(0).Type()
 
290
        }
 
291
        return ptype, rtype, nil
 
292
}