1
// Copyright 2016 Canonical Ltd.
2
// Licensed under the AGPLv3, see LICENCE file for details.
17
"github.com/juju/errors"
18
"github.com/juju/utils"
19
"github.com/juju/utils/set"
20
"github.com/juju/utils/ssh"
21
"gopkg.in/juju/names.v2"
22
"launchpad.net/gnuflag"
24
"github.com/juju/juju/api/sshclient"
25
"github.com/juju/juju/cmd/modelcmd"
28
// SSHCommon implements functionality shared by sshCommand, SCPCommand
29
// and DebugHooksCommand.
30
type SSHCommon struct {
31
modelcmd.ModelCommandBase
37
apiClient sshAPIClient
42
type sshAPIClient interface {
43
PublicAddress(target string) (string, error)
44
PrivateAddress(target string) (string, error)
45
PublicKeys(target string) ([]string, error)
50
type resolvedTarget struct {
56
func (t *resolvedTarget) userHost() string {
60
return t.user + "@" + t.host
63
func (t *resolvedTarget) isAgent() bool {
64
return targetIsAgent(t.entity)
67
// attemptStarter is an interface corresponding to utils.AttemptStrategy
69
// TODO(katco): 2016-08-09: lp:1611427
70
type attemptStarter interface {
74
type attempt interface {
78
// TODO(katco): 2016-08-09: lp:1611427
79
type attemptStrategy utils.AttemptStrategy
81
func (s attemptStrategy) Start() attempt {
82
// TODO(katco): 2016-08-09: lp:1611427
83
return utils.AttemptStrategy(s).Start()
86
var sshHostFromTargetAttemptStrategy attemptStarter = attemptStrategy{
87
Total: 5 * time.Second,
88
Delay: 500 * time.Millisecond,
91
func (c *SSHCommon) SetFlags(f *gnuflag.FlagSet) {
92
f.BoolVar(&c.proxy, "proxy", false, "Proxy through the API server")
93
f.BoolVar(&c.pty, "pty", true, "Enable pseudo-tty allocation")
94
f.BoolVar(&c.noHostKeyChecks, "no-host-key-checks", false, "Skip host key checking (INSECURE)")
97
// initRun initializes the API connection if required, and determines
98
// if SSH proxying is required. It must be called at the top of the
99
// command's Run method.
101
// The apiClient, apiAddr and proxy fields are initialized after this
103
func (c *SSHCommon) initRun() error {
104
if err := c.ensureAPIClient(); err != nil {
105
return errors.Trace(err)
107
if proxy, err := c.proxySSH(); err != nil {
108
return errors.Trace(err)
115
// cleanupRun removes the temporary SSH known_hosts file (if one was
116
// created) and closes the API connection. It must be called at the
117
// end of the command's Run (i.e. as a defer).
118
func (c *SSHCommon) cleanupRun() {
119
if c.knownHostsPath != "" {
120
os.Remove(c.knownHostsPath)
121
c.knownHostsPath = ""
123
if c.apiClient != nil {
129
// getSSHOptions configures SSH options based on command line
130
// arguments and the SSH targets specified.
131
func (c *SSHCommon) getSSHOptions(enablePty bool, targets ...*resolvedTarget) (*ssh.Options, error) {
132
var options ssh.Options
134
if c.noHostKeyChecks {
135
options.SetStrictHostKeyChecking(ssh.StrictHostChecksNo)
136
options.SetKnownHostsFile("/dev/null")
138
knownHostsPath, err := c.generateKnownHosts(targets)
140
return nil, errors.Trace(err)
143
// There might not be a custom known_hosts file if the SSH
144
// targets are specified using arbitrary hostnames or
145
// addresses. In this case, the user's personal known_hosts
148
if knownHostsPath != "" {
149
// When a known_hosts file has been generated, enforce
150
// strict host key checking.
151
options.SetStrictHostKeyChecking(ssh.StrictHostChecksYes)
152
options.SetKnownHostsFile(knownHostsPath)
154
// If the user's personal known_hosts is used, also use
155
// the user's personal StrictHostKeyChecking preferences.
156
options.SetStrictHostKeyChecking(ssh.StrictHostChecksUnset)
165
if err := c.setProxyCommand(&options); err != nil {
173
// generateKnownHosts takes the provided targets, retrieves the SSH
174
// public host keys for them and generates a temporary known_hosts
176
func (c *SSHCommon) generateKnownHosts(targets []*resolvedTarget) (string, error) {
177
knownHosts := newKnownHostsBuilder()
180
for _, target := range targets {
181
if target.isAgent() {
183
keys, err := c.apiClient.PublicKeys(target.entity)
185
return "", errors.Annotatef(err, "retrieving SSH host keys for %q", target.entity)
187
knownHosts.add(target.host, keys)
193
if agentCount > 0 && nonAgentCount > 0 {
194
return "", errors.New("can't determine host keys for all targets: consider --no-host-key-checks")
197
if knownHosts.size() == 0 {
198
// No public keys to write so exit early.
202
f, err := ioutil.TempFile("", "ssh_known_hosts")
204
return "", errors.Annotate(err, "creating known hosts file")
207
c.knownHostsPath = f.Name() // Record for later deletion
208
if knownHosts.write(f); err != nil {
209
return "", errors.Trace(err)
211
return c.knownHostsPath, nil
214
// proxySSH returns false if both c.proxy and the proxy-ssh model
215
// configuration are false -- otherwise it returns true.
216
func (c *SSHCommon) proxySSH() (bool, error) {
218
// No need to check the API if user explictly requested
222
proxy, err := c.apiClient.Proxy()
224
return false, errors.Trace(err)
226
logger.Debugf("proxy-ssh is %v", proxy)
230
// setProxyCommand sets the proxy command option.
231
func (c *SSHCommon) setProxyCommand(options *ssh.Options) error {
232
apiServerHost, _, err := net.SplitHostPort(c.apiAddr)
234
return fmt.Errorf("failed to get proxy address: %v", err)
236
juju, err := getJujuExecutable()
238
return fmt.Errorf("failed to get juju executable path: %v", err)
241
// TODO(mjs) 2016-05-09 LP #1579592 - It would be good to check the
242
// host key of the controller machine being used for proxying
243
// here. This isn't too serious as all traffic passing through the
244
// controller host is encrypted and the host key of the ultimate
245
// target host is verified but it would still be better to perform
246
// this extra level of checking.
247
options.SetProxyCommand(
250
"--no-host-key-checks",
252
"ubuntu@"+apiServerHost,
259
func (c *SSHCommon) ensureAPIClient() error {
260
if c.apiClient != nil {
263
return errors.Trace(c.initAPIClient())
266
// initAPIClient initialises the API connection.
267
func (c *SSHCommon) initAPIClient() error {
268
conn, err := c.NewAPIRoot()
270
return errors.Trace(err)
272
c.apiClient = sshclient.NewFacade(conn)
273
c.apiAddr = conn.Addr()
277
func (c *SSHCommon) resolveTarget(target string) (*resolvedTarget, error) {
278
out := new(resolvedTarget)
279
out.user, out.entity = splitUserTarget(target)
281
// If the target is neither a machine nor a unit assume it's a
282
// hostname and try it directly.
283
if !targetIsAgent(out.entity) {
284
out.host = out.entity
292
// A target may not initially have an address (e.g. the
293
// address updater hasn't yet run), so we must do this in
296
for a := sshHostFromTargetAttemptStrategy.Start(); a.Next(); {
298
out.host, err = c.apiClient.PrivateAddress(out.entity)
300
out.host, err = c.apiClient.PublicAddress(out.entity)
309
// AllowInterspersedFlags for ssh/scp is set to false so that
310
// flags after the unit name are passed through to ssh, for eg.
311
// `juju ssh -v application-name/0 uname -a`.
312
func (c *SSHCommon) AllowInterspersedFlags() bool {
316
// getJujuExecutable returns the path to the juju
317
// executable, or an error if it could not be found.
318
var getJujuExecutable = func() (string, error) {
319
return exec.LookPath(os.Args[0])
322
func targetIsAgent(target string) bool {
323
return names.IsValidMachine(target) || names.IsValidUnit(target)
326
func splitUserTarget(target string) (string, string) {
327
if i := strings.IndexRune(target, '@'); i != -1 {
328
return target[:i], target[i+1:]
333
func newKnownHostsBuilder() *knownHostsBuilder {
334
return &knownHostsBuilder{
335
seen: set.NewStrings(),
339
// knownHostsBuilder supports the construction of a SSH known_hosts file.
340
type knownHostsBuilder struct {
345
func (b *knownHostsBuilder) add(host string, keys []string) {
346
if b.seen.Contains(host) {
350
for _, key := range keys {
351
b.lines = append(b.lines, host+" "+key+"\n")
355
func (b *knownHostsBuilder) write(w io.Writer) error {
356
bufw := bufio.NewWriter(w)
357
for _, line := range b.lines {
358
_, err := bufw.WriteString(line)
360
return errors.Annotate(err, "writing known hosts file")
367
func (b *knownHostsBuilder) size() int {