~nskaggs/+junk/xenial-test

« back to all changes in this revision

Viewing changes to src/github.com/lxc/lxd/lxd/devlxd.go

  • Committer: Nicholas Skaggs
  • Date: 2016-10-24 20:56:05 UTC
  • Revision ID: nicholas.skaggs@canonical.com-20161024205605-z8lta0uvuhtxwzwl
Initi with beta15

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
package main
 
2
 
 
3
import (
 
4
        "fmt"
 
5
        "io/ioutil"
 
6
        "net"
 
7
        "net/http"
 
8
        "os"
 
9
        "path"
 
10
        "reflect"
 
11
        "regexp"
 
12
        "strconv"
 
13
        "strings"
 
14
        "unsafe"
 
15
 
 
16
        "github.com/gorilla/mux"
 
17
 
 
18
        "github.com/lxc/lxd/shared"
 
19
)
 
20
 
 
21
type devLxdResponse struct {
 
22
        content interface{}
 
23
        code    int
 
24
        ctype   string
 
25
}
 
26
 
 
27
func okResponse(ct interface{}, ctype string) *devLxdResponse {
 
28
        return &devLxdResponse{ct, http.StatusOK, ctype}
 
29
}
 
30
 
 
31
type devLxdHandler struct {
 
32
        path string
 
33
 
 
34
        /*
 
35
         * This API will have to be changed slightly when we decide to support
 
36
         * websocket events upgrading, but since we don't have events on the
 
37
         * server side right now either, I went the simple route to avoid
 
38
         * needless noise.
 
39
         */
 
40
        f func(c container, r *http.Request) *devLxdResponse
 
41
}
 
42
 
 
43
var configGet = devLxdHandler{"/1.0/config", func(c container, r *http.Request) *devLxdResponse {
 
44
        filtered := []string{}
 
45
        for k, _ := range c.ExpandedConfig() {
 
46
                if strings.HasPrefix(k, "user.") {
 
47
                        filtered = append(filtered, fmt.Sprintf("/1.0/config/%s", k))
 
48
                }
 
49
        }
 
50
        return okResponse(filtered, "json")
 
51
}}
 
52
 
 
53
var configKeyGet = devLxdHandler{"/1.0/config/{key}", func(c container, r *http.Request) *devLxdResponse {
 
54
        key := mux.Vars(r)["key"]
 
55
        if !strings.HasPrefix(key, "user.") {
 
56
                return &devLxdResponse{"not authorized", http.StatusForbidden, "raw"}
 
57
        }
 
58
 
 
59
        value, ok := c.ExpandedConfig()[key]
 
60
        if !ok {
 
61
                return &devLxdResponse{"not found", http.StatusNotFound, "raw"}
 
62
        }
 
63
 
 
64
        return okResponse(value, "raw")
 
65
}}
 
66
 
 
67
var metadataGet = devLxdHandler{"/1.0/meta-data", func(c container, r *http.Request) *devLxdResponse {
 
68
        value := c.ExpandedConfig()["user.meta-data"]
 
69
        return okResponse(fmt.Sprintf("#cloud-config\ninstance-id: %s\nlocal-hostname: %s\n%s", c.Name(), c.Name(), value), "raw")
 
70
}}
 
71
 
 
72
var handlers = []devLxdHandler{
 
73
        devLxdHandler{"/", func(c container, r *http.Request) *devLxdResponse {
 
74
                return okResponse([]string{"/1.0"}, "json")
 
75
        }},
 
76
        devLxdHandler{"/1.0", func(c container, r *http.Request) *devLxdResponse {
 
77
                return okResponse(shared.Jmap{"api_version": shared.APIVersion}, "json")
 
78
        }},
 
79
        configGet,
 
80
        configKeyGet,
 
81
        metadataGet,
 
82
        /* TODO: events */
 
83
}
 
84
 
 
85
func hoistReq(f func(container, *http.Request) *devLxdResponse, d *Daemon) func(http.ResponseWriter, *http.Request) {
 
86
        return func(w http.ResponseWriter, r *http.Request) {
 
87
                conn := extractUnderlyingConn(w)
 
88
                cred, ok := pidMapper.m[conn]
 
89
                if !ok {
 
90
                        http.Error(w, pidNotInContainerErr.Error(), 500)
 
91
                        return
 
92
                }
 
93
 
 
94
                c, err := findContainerForPid(cred.pid, d)
 
95
                if err != nil {
 
96
                        http.Error(w, err.Error(), 500)
 
97
                        return
 
98
                }
 
99
 
 
100
                // Access control
 
101
                rootUid := int64(0)
 
102
 
 
103
                idmapset, err := c.LastIdmapSet()
 
104
                if err == nil && idmapset != nil {
 
105
                        uid, _ := idmapset.ShiftIntoNs(0, 0)
 
106
                        rootUid = int64(uid)
 
107
                }
 
108
 
 
109
                if rootUid != cred.uid {
 
110
                        http.Error(w, "Access denied for non-root user", 401)
 
111
                        return
 
112
                }
 
113
 
 
114
                resp := f(c, r)
 
115
                if resp.code != http.StatusOK {
 
116
                        http.Error(w, fmt.Sprintf("%s", resp.content), resp.code)
 
117
                } else if resp.ctype == "json" {
 
118
                        w.Header().Set("Content-Type", "application/json")
 
119
                        WriteJSON(w, resp.content)
 
120
                } else {
 
121
                        w.Header().Set("Content-Type", "application/octet-stream")
 
122
                        fmt.Fprintf(w, resp.content.(string))
 
123
                }
 
124
        }
 
125
}
 
126
 
 
127
func createAndBindDevLxd() (*net.UnixListener, error) {
 
128
        sockFile := path.Join(shared.VarPath("devlxd"), "sock")
 
129
 
 
130
        /*
 
131
         * If this socket exists, that means a previous lxd died and didn't
 
132
         * clean up after itself. We assume that the LXD is actually dead if we
 
133
         * get this far, since StartDaemon() tries to connect to the actual lxd
 
134
         * socket to make sure that it is actually dead. So, it is safe to
 
135
         * remove it here without any checks.
 
136
         *
 
137
         * Also, it would be nice to SO_REUSEADDR here so we don't have to
 
138
         * delete the socket, but we can't:
 
139
         *   http://stackoverflow.com/questions/15716302/so-reuseaddr-and-af-unix
 
140
         *
 
141
         * Note that this will force clients to reconnect when LXD is restarted.
 
142
         */
 
143
        if err := os.Remove(sockFile); err != nil && !os.IsNotExist(err) {
 
144
                return nil, err
 
145
        }
 
146
 
 
147
        unixAddr, err := net.ResolveUnixAddr("unix", sockFile)
 
148
        if err != nil {
 
149
                return nil, err
 
150
        }
 
151
 
 
152
        unixl, err := net.ListenUnix("unix", unixAddr)
 
153
        if err != nil {
 
154
                return nil, err
 
155
        }
 
156
 
 
157
        if err := os.Chmod(sockFile, 0666); err != nil {
 
158
                return nil, err
 
159
        }
 
160
 
 
161
        return unixl, nil
 
162
}
 
163
 
 
164
func devLxdServer(d *Daemon) *http.Server {
 
165
        m := mux.NewRouter()
 
166
 
 
167
        for _, handler := range handlers {
 
168
                m.HandleFunc(handler.path, hoistReq(handler.f, d))
 
169
        }
 
170
 
 
171
        return &http.Server{
 
172
                Handler:   m,
 
173
                ConnState: pidMapper.ConnStateHandler,
 
174
        }
 
175
}
 
176
 
 
177
/*
 
178
 * Everything below here is the guts of the unix socket bits. Unfortunately,
 
179
 * golang's API does not make this easy. What happens is:
 
180
 *
 
181
 * 1. We install a ConnState listener on the http.Server, which does the
 
182
 *    initial unix socket credential exchange. When we get a connection started
 
183
 *    event, we use SO_PEERCRED to extract the creds for the socket.
 
184
 *
 
185
 * 2. We store a map from the connection pointer to the pid for that
 
186
 *    connection, so that once the HTTP negotiation occurrs and we get a
 
187
 *    ResponseWriter, we know (because we negotiated on the first byte) which
 
188
 *    pid the connection belogs to.
 
189
 *
 
190
 * 3. Regular HTTP negotiation and dispatch occurs via net/http.
 
191
 *
 
192
 * 4. When rendering the response via ResponseWriter, we match its underlying
 
193
 *    connection against what we stored in step (2) to figure out which container
 
194
 *    it came from.
 
195
 */
 
196
 
 
197
/*
 
198
 * We keep this in a global so that we can reference it from the server and
 
199
 * from our http handlers, since there appears to be no way to pass information
 
200
 * around here.
 
201
 */
 
202
var pidMapper = ConnPidMapper{m: map[*net.UnixConn]*ucred{}}
 
203
 
 
204
type ucred struct {
 
205
        pid int32
 
206
        uid int64
 
207
        gid int64
 
208
}
 
209
 
 
210
type ConnPidMapper struct {
 
211
        m map[*net.UnixConn]*ucred
 
212
}
 
213
 
 
214
func (m *ConnPidMapper) ConnStateHandler(conn net.Conn, state http.ConnState) {
 
215
        unixConn := conn.(*net.UnixConn)
 
216
        switch state {
 
217
        case http.StateNew:
 
218
                cred, err := getCred(unixConn)
 
219
                if err != nil {
 
220
                        shared.Debugf("Error getting ucred for conn %s", err)
 
221
                } else {
 
222
                        m.m[unixConn] = cred
 
223
                }
 
224
        case http.StateActive:
 
225
                return
 
226
        case http.StateIdle:
 
227
                return
 
228
        case http.StateHijacked:
 
229
                /*
 
230
                 * The "Hijacked" state indicates that the connection has been
 
231
                 * taken over from net/http. This is useful for things like
 
232
                 * developing websocket libraries, who want to upgrade the
 
233
                 * connection to a websocket one, and not use net/http any
 
234
                 * more. Whatever the case, we want to forget about it since we
 
235
                 * won't see it either.
 
236
                 */
 
237
                delete(m.m, unixConn)
 
238
        case http.StateClosed:
 
239
                delete(m.m, unixConn)
 
240
        default:
 
241
                shared.Debugf("Unknown state for connection %s", state)
 
242
        }
 
243
}
 
244
 
 
245
/*
 
246
 * I also don't see that golang exports an API to get at the underlying FD, but
 
247
 * we need it to get at SO_PEERCRED, so let's grab it.
 
248
 */
 
249
func extractUnderlyingFd(unixConnPtr *net.UnixConn) int {
 
250
        conn := reflect.Indirect(reflect.ValueOf(unixConnPtr))
 
251
        netFdPtr := conn.FieldByName("fd")
 
252
        netFd := reflect.Indirect(netFdPtr)
 
253
        fd := netFd.FieldByName("sysfd")
 
254
        return int(fd.Int())
 
255
}
 
256
 
 
257
func getCred(conn *net.UnixConn) (*ucred, error) {
 
258
        fd := extractUnderlyingFd(conn)
 
259
 
 
260
        uid, gid, pid, err := getUcred(fd)
 
261
        if err != nil {
 
262
                return nil, err
 
263
        }
 
264
 
 
265
        return &ucred{pid, int64(uid), int64(gid)}, nil
 
266
}
 
267
 
 
268
/*
 
269
 * As near as I can tell, there is no nice way of extracting an underlying
 
270
 * net.Conn (or in our case, net.UnixConn) from an http.Request or
 
271
 * ResponseWriter without hijacking it [1]. Since we want to send and recieve
 
272
 * unix creds to figure out which container this request came from, we need to
 
273
 * do this.
 
274
 *
 
275
 * [1]: https://groups.google.com/forum/#!topic/golang-nuts/_FWdFXJa6QA
 
276
 */
 
277
func extractUnderlyingConn(w http.ResponseWriter) *net.UnixConn {
 
278
        v := reflect.Indirect(reflect.ValueOf(w))
 
279
        connPtr := v.FieldByName("conn")
 
280
        conn := reflect.Indirect(connPtr)
 
281
        rwc := conn.FieldByName("rwc")
 
282
 
 
283
        netConnPtr := (*net.Conn)(unsafe.Pointer(rwc.UnsafeAddr()))
 
284
        unixConnPtr := (*netConnPtr).(*net.UnixConn)
 
285
 
 
286
        return unixConnPtr
 
287
}
 
288
 
 
289
var pidNotInContainerErr = fmt.Errorf("pid not in container?")
 
290
 
 
291
func findContainerForPid(pid int32, d *Daemon) (container, error) {
 
292
        /*
 
293
         * Try and figure out which container a pid is in. There is probably a
 
294
         * better way to do this. Based on rharper's initial performance
 
295
         * metrics, looping over every container and calling newLxdContainer is
 
296
         * expensive, so I wanted to avoid that if possible, so this happens in
 
297
         * a two step process:
 
298
         *
 
299
         * 1. Walk up the process tree until you see something that looks like
 
300
         *    an lxc monitor process and extract its name from there.
 
301
         *
 
302
         * 2. If this fails, it may be that someone did an `lxc exec foo bash`,
 
303
         *    so the process isn't actually a decendant of the container's
 
304
         *    init. In this case we just look through all the containers until
 
305
         *    we find an init with a matching pid namespace. This is probably
 
306
         *    uncommon, so hopefully the slowness won't hurt us.
 
307
         */
 
308
 
 
309
        origpid := pid
 
310
 
 
311
        for pid > 1 {
 
312
                cmdline, err := ioutil.ReadFile(fmt.Sprintf("/proc/%d/cmdline", pid))
 
313
                if err != nil {
 
314
                        return nil, err
 
315
                }
 
316
 
 
317
                if strings.HasPrefix(string(cmdline), "[lxc monitor]") {
 
318
                        // container names can't have spaces
 
319
                        parts := strings.Split(string(cmdline), " ")
 
320
                        name := strings.TrimSuffix(parts[len(parts)-1], "\x00")
 
321
 
 
322
                        return containerLoadByName(d, name)
 
323
                }
 
324
 
 
325
                status, err := ioutil.ReadFile(fmt.Sprintf("/proc/%d/status", pid))
 
326
                if err != nil {
 
327
                        return nil, err
 
328
                }
 
329
 
 
330
                re := regexp.MustCompile("PPid:\\s*([0-9]*)")
 
331
                for _, line := range strings.Split(string(status), "\n") {
 
332
                        m := re.FindStringSubmatch(line)
 
333
                        if m != nil && len(m) > 1 {
 
334
                                result, err := strconv.Atoi(m[1])
 
335
                                if err != nil {
 
336
                                        return nil, err
 
337
                                }
 
338
 
 
339
                                pid = int32(result)
 
340
                                break
 
341
                        }
 
342
                }
 
343
        }
 
344
 
 
345
        origPidNs, err := os.Readlink(fmt.Sprintf("/proc/%d/ns/pid", origpid))
 
346
        if err != nil {
 
347
                return nil, err
 
348
        }
 
349
 
 
350
        containers, err := dbContainersList(d.db, cTypeRegular)
 
351
        if err != nil {
 
352
                return nil, err
 
353
        }
 
354
 
 
355
        for _, container := range containers {
 
356
                c, err := containerLoadByName(d, container)
 
357
                if err != nil {
 
358
                        return nil, err
 
359
                }
 
360
 
 
361
                if !c.IsRunning() {
 
362
                        continue
 
363
                }
 
364
 
 
365
                initpid := c.InitPID()
 
366
                pidNs, err := os.Readlink(fmt.Sprintf("/proc/%d/ns/pid", initpid))
 
367
                if err != nil {
 
368
                        return nil, err
 
369
                }
 
370
 
 
371
                if origPidNs == pidNs {
 
372
                        return c, nil
 
373
                }
 
374
        }
 
375
 
 
376
        return nil, pidNotInContainerErr
 
377
}