~nskaggs/+junk/xenial-test

« back to all changes in this revision

Viewing changes to src/github.com/juju/juju/cmd/juju/commands/ssh_common.go

  • Committer: Nicholas Skaggs
  • Date: 2016-10-24 20:56:05 UTC
  • Revision ID: nicholas.skaggs@canonical.com-20161024205605-z8lta0uvuhtxwzwl
Initi with beta15

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
// Copyright 2016 Canonical Ltd.
 
2
// Licensed under the AGPLv3, see LICENCE file for details.
 
3
 
 
4
package commands
 
5
 
 
6
import (
 
7
        "bufio"
 
8
        "fmt"
 
9
        "io"
 
10
        "io/ioutil"
 
11
        "net"
 
12
        "os"
 
13
        "os/exec"
 
14
        "strings"
 
15
        "time"
 
16
 
 
17
        "github.com/juju/errors"
 
18
        "github.com/juju/utils"
 
19
        "github.com/juju/utils/set"
 
20
        "github.com/juju/utils/ssh"
 
21
        "gopkg.in/juju/names.v2"
 
22
        "launchpad.net/gnuflag"
 
23
 
 
24
        "github.com/juju/juju/api/sshclient"
 
25
        "github.com/juju/juju/cmd/modelcmd"
 
26
)
 
27
 
 
28
// SSHCommon implements functionality shared by sshCommand, SCPCommand
 
29
// and DebugHooksCommand.
 
30
type SSHCommon struct {
 
31
        modelcmd.ModelCommandBase
 
32
        proxy           bool
 
33
        pty             bool
 
34
        noHostKeyChecks bool
 
35
        Target          string
 
36
        Args            []string
 
37
        apiClient       sshAPIClient
 
38
        apiAddr         string
 
39
        knownHostsPath  string
 
40
}
 
41
 
 
42
type sshAPIClient interface {
 
43
        PublicAddress(target string) (string, error)
 
44
        PrivateAddress(target string) (string, error)
 
45
        PublicKeys(target string) ([]string, error)
 
46
        Proxy() (bool, error)
 
47
        Close() error
 
48
}
 
49
 
 
50
type resolvedTarget struct {
 
51
        user   string
 
52
        entity string
 
53
        host   string
 
54
}
 
55
 
 
56
func (t *resolvedTarget) userHost() string {
 
57
        if t.user == "" {
 
58
                return t.host
 
59
        }
 
60
        return t.user + "@" + t.host
 
61
}
 
62
 
 
63
func (t *resolvedTarget) isAgent() bool {
 
64
        return targetIsAgent(t.entity)
 
65
}
 
66
 
 
67
// attemptStarter is an interface corresponding to utils.AttemptStrategy
 
68
//
 
69
// TODO(katco): 2016-08-09: lp:1611427
 
70
type attemptStarter interface {
 
71
        Start() attempt
 
72
}
 
73
 
 
74
type attempt interface {
 
75
        Next() bool
 
76
}
 
77
 
 
78
// TODO(katco): 2016-08-09: lp:1611427
 
79
type attemptStrategy utils.AttemptStrategy
 
80
 
 
81
func (s attemptStrategy) Start() attempt {
 
82
        // TODO(katco): 2016-08-09: lp:1611427
 
83
        return utils.AttemptStrategy(s).Start()
 
84
}
 
85
 
 
86
var sshHostFromTargetAttemptStrategy attemptStarter = attemptStrategy{
 
87
        Total: 5 * time.Second,
 
88
        Delay: 500 * time.Millisecond,
 
89
}
 
90
 
 
91
func (c *SSHCommon) SetFlags(f *gnuflag.FlagSet) {
 
92
        f.BoolVar(&c.proxy, "proxy", false, "Proxy through the API server")
 
93
        f.BoolVar(&c.pty, "pty", true, "Enable pseudo-tty allocation")
 
94
        f.BoolVar(&c.noHostKeyChecks, "no-host-key-checks", false, "Skip host key checking (INSECURE)")
 
95
}
 
96
 
 
97
// initRun initializes the API connection if required, and determines
 
98
// if SSH proxying is required. It must be called at the top of the
 
99
// command's Run method.
 
100
//
 
101
// The apiClient, apiAddr and proxy fields are initialized after this
 
102
// call.
 
103
func (c *SSHCommon) initRun() error {
 
104
        if err := c.ensureAPIClient(); err != nil {
 
105
                return errors.Trace(err)
 
106
        }
 
107
        if proxy, err := c.proxySSH(); err != nil {
 
108
                return errors.Trace(err)
 
109
        } else {
 
110
                c.proxy = proxy
 
111
        }
 
112
        return nil
 
113
}
 
114
 
 
115
// cleanupRun removes the temporary SSH known_hosts file (if one was
 
116
// created) and closes the API connection. It must be called at the
 
117
// end of the command's Run (i.e. as a defer).
 
118
func (c *SSHCommon) cleanupRun() {
 
119
        if c.knownHostsPath != "" {
 
120
                os.Remove(c.knownHostsPath)
 
121
                c.knownHostsPath = ""
 
122
        }
 
123
        if c.apiClient != nil {
 
124
                c.apiClient.Close()
 
125
                c.apiClient = nil
 
126
        }
 
127
}
 
128
 
 
129
// getSSHOptions configures SSH options based on command line
 
130
// arguments and the SSH targets specified.
 
131
func (c *SSHCommon) getSSHOptions(enablePty bool, targets ...*resolvedTarget) (*ssh.Options, error) {
 
132
        var options ssh.Options
 
133
 
 
134
        if c.noHostKeyChecks {
 
135
                options.SetStrictHostKeyChecking(ssh.StrictHostChecksNo)
 
136
                options.SetKnownHostsFile("/dev/null")
 
137
        } else {
 
138
                knownHostsPath, err := c.generateKnownHosts(targets)
 
139
                if err != nil {
 
140
                        return nil, errors.Trace(err)
 
141
                }
 
142
 
 
143
                // There might not be a custom known_hosts file if the SSH
 
144
                // targets are specified using arbitrary hostnames or
 
145
                // addresses. In this case, the user's personal known_hosts
 
146
                // file is used.
 
147
 
 
148
                if knownHostsPath != "" {
 
149
                        // When a known_hosts file has been generated, enforce
 
150
                        // strict host key checking.
 
151
                        options.SetStrictHostKeyChecking(ssh.StrictHostChecksYes)
 
152
                        options.SetKnownHostsFile(knownHostsPath)
 
153
                } else {
 
154
                        // If the user's personal known_hosts is used, also use
 
155
                        // the user's personal StrictHostKeyChecking preferences.
 
156
                        options.SetStrictHostKeyChecking(ssh.StrictHostChecksUnset)
 
157
                }
 
158
        }
 
159
 
 
160
        if enablePty {
 
161
                options.EnablePTY()
 
162
        }
 
163
 
 
164
        if c.proxy {
 
165
                if err := c.setProxyCommand(&options); err != nil {
 
166
                        return nil, err
 
167
                }
 
168
        }
 
169
 
 
170
        return &options, nil
 
171
}
 
172
 
 
173
// generateKnownHosts takes the provided targets, retrieves the SSH
 
174
// public host keys for them and generates a temporary known_hosts
 
175
// file for them.
 
176
func (c *SSHCommon) generateKnownHosts(targets []*resolvedTarget) (string, error) {
 
177
        knownHosts := newKnownHostsBuilder()
 
178
        agentCount := 0
 
179
        nonAgentCount := 0
 
180
        for _, target := range targets {
 
181
                if target.isAgent() {
 
182
                        agentCount++
 
183
                        keys, err := c.apiClient.PublicKeys(target.entity)
 
184
                        if err != nil {
 
185
                                return "", errors.Annotatef(err, "retrieving SSH host keys for %q", target.entity)
 
186
                        }
 
187
                        knownHosts.add(target.host, keys)
 
188
                } else {
 
189
                        nonAgentCount++
 
190
                }
 
191
        }
 
192
 
 
193
        if agentCount > 0 && nonAgentCount > 0 {
 
194
                return "", errors.New("can't determine host keys for all targets: consider --no-host-key-checks")
 
195
        }
 
196
 
 
197
        if knownHosts.size() == 0 {
 
198
                // No public keys to write so exit early.
 
199
                return "", nil
 
200
        }
 
201
 
 
202
        f, err := ioutil.TempFile("", "ssh_known_hosts")
 
203
        if err != nil {
 
204
                return "", errors.Annotate(err, "creating known hosts file")
 
205
        }
 
206
        defer f.Close()
 
207
        c.knownHostsPath = f.Name() // Record for later deletion
 
208
        if knownHosts.write(f); err != nil {
 
209
                return "", errors.Trace(err)
 
210
        }
 
211
        return c.knownHostsPath, nil
 
212
}
 
213
 
 
214
// proxySSH returns false if both c.proxy and the proxy-ssh model
 
215
// configuration are false -- otherwise it returns true.
 
216
func (c *SSHCommon) proxySSH() (bool, error) {
 
217
        if c.proxy {
 
218
                // No need to check the API if user explictly requested
 
219
                // proxying.
 
220
                return true, nil
 
221
        }
 
222
        proxy, err := c.apiClient.Proxy()
 
223
        if err != nil {
 
224
                return false, errors.Trace(err)
 
225
        }
 
226
        logger.Debugf("proxy-ssh is %v", proxy)
 
227
        return proxy, nil
 
228
}
 
229
 
 
230
// setProxyCommand sets the proxy command option.
 
231
func (c *SSHCommon) setProxyCommand(options *ssh.Options) error {
 
232
        apiServerHost, _, err := net.SplitHostPort(c.apiAddr)
 
233
        if err != nil {
 
234
                return fmt.Errorf("failed to get proxy address: %v", err)
 
235
        }
 
236
        juju, err := getJujuExecutable()
 
237
        if err != nil {
 
238
                return fmt.Errorf("failed to get juju executable path: %v", err)
 
239
        }
 
240
 
 
241
        // TODO(mjs) 2016-05-09 LP #1579592 - It would be good to check the
 
242
        // host key of the controller machine being used for proxying
 
243
        // here. This isn't too serious as all traffic passing through the
 
244
        // controller host is encrypted and the host key of the ultimate
 
245
        // target host is verified but it would still be better to perform
 
246
        // this extra level of checking.
 
247
        options.SetProxyCommand(
 
248
                juju, "ssh",
 
249
                "--proxy=false",
 
250
                "--no-host-key-checks",
 
251
                "--pty=false",
 
252
                "ubuntu@"+apiServerHost,
 
253
                "-q",
 
254
                "nc %h %p",
 
255
        )
 
256
        return nil
 
257
}
 
258
 
 
259
func (c *SSHCommon) ensureAPIClient() error {
 
260
        if c.apiClient != nil {
 
261
                return nil
 
262
        }
 
263
        return errors.Trace(c.initAPIClient())
 
264
}
 
265
 
 
266
// initAPIClient initialises the API connection.
 
267
func (c *SSHCommon) initAPIClient() error {
 
268
        conn, err := c.NewAPIRoot()
 
269
        if err != nil {
 
270
                return errors.Trace(err)
 
271
        }
 
272
        c.apiClient = sshclient.NewFacade(conn)
 
273
        c.apiAddr = conn.Addr()
 
274
        return nil
 
275
}
 
276
 
 
277
func (c *SSHCommon) resolveTarget(target string) (*resolvedTarget, error) {
 
278
        out := new(resolvedTarget)
 
279
        out.user, out.entity = splitUserTarget(target)
 
280
 
 
281
        // If the target is neither a machine nor a unit assume it's a
 
282
        // hostname and try it directly.
 
283
        if !targetIsAgent(out.entity) {
 
284
                out.host = out.entity
 
285
                return out, nil
 
286
        }
 
287
 
 
288
        if out.user == "" {
 
289
                out.user = "ubuntu"
 
290
        }
 
291
 
 
292
        // A target may not initially have an address (e.g. the
 
293
        // address updater hasn't yet run), so we must do this in
 
294
        // a loop.
 
295
        var err error
 
296
        for a := sshHostFromTargetAttemptStrategy.Start(); a.Next(); {
 
297
                if c.proxy {
 
298
                        out.host, err = c.apiClient.PrivateAddress(out.entity)
 
299
                } else {
 
300
                        out.host, err = c.apiClient.PublicAddress(out.entity)
 
301
                }
 
302
                if err == nil {
 
303
                        return out, nil
 
304
                }
 
305
        }
 
306
        return nil, err
 
307
}
 
308
 
 
309
// AllowInterspersedFlags for ssh/scp is set to false so that
 
310
// flags after the unit name are passed through to ssh, for eg.
 
311
// `juju ssh -v application-name/0 uname -a`.
 
312
func (c *SSHCommon) AllowInterspersedFlags() bool {
 
313
        return false
 
314
}
 
315
 
 
316
// getJujuExecutable returns the path to the juju
 
317
// executable, or an error if it could not be found.
 
318
var getJujuExecutable = func() (string, error) {
 
319
        return exec.LookPath(os.Args[0])
 
320
}
 
321
 
 
322
func targetIsAgent(target string) bool {
 
323
        return names.IsValidMachine(target) || names.IsValidUnit(target)
 
324
}
 
325
 
 
326
func splitUserTarget(target string) (string, string) {
 
327
        if i := strings.IndexRune(target, '@'); i != -1 {
 
328
                return target[:i], target[i+1:]
 
329
        }
 
330
        return "", target
 
331
}
 
332
 
 
333
func newKnownHostsBuilder() *knownHostsBuilder {
 
334
        return &knownHostsBuilder{
 
335
                seen: set.NewStrings(),
 
336
        }
 
337
}
 
338
 
 
339
// knownHostsBuilder supports the construction of a SSH known_hosts file.
 
340
type knownHostsBuilder struct {
 
341
        lines []string
 
342
        seen  set.Strings
 
343
}
 
344
 
 
345
func (b *knownHostsBuilder) add(host string, keys []string) {
 
346
        if b.seen.Contains(host) {
 
347
                return
 
348
        }
 
349
        b.seen.Add(host)
 
350
        for _, key := range keys {
 
351
                b.lines = append(b.lines, host+" "+key+"\n")
 
352
        }
 
353
}
 
354
 
 
355
func (b *knownHostsBuilder) write(w io.Writer) error {
 
356
        bufw := bufio.NewWriter(w)
 
357
        for _, line := range b.lines {
 
358
                _, err := bufw.WriteString(line)
 
359
                if err != nil {
 
360
                        return errors.Annotate(err, "writing known hosts file")
 
361
                }
 
362
        }
 
363
        bufw.Flush()
 
364
        return nil
 
365
}
 
366
 
 
367
func (b *knownHostsBuilder) size() int {
 
368
        return len(b.lines)
 
369
}