~rogpeppe/+junk/mgo-tagged-log-messages

« back to all changes in this revision

Viewing changes to sasl/sasl.go

  • Committer: Roger Peppe
  • Date: 2014-03-14 18:11:33 UTC
  • mfrom: (263.1.8 master)
  • Revision ID: roger.peppe@canonical.com-20140314181133-107ag3xpitk9682u
merge trunk

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
// Package sasl is an implementation detail of the mgo package.
 
2
//
 
3
// This package is not meant to be used by itself.
 
4
//
 
5
package sasl
 
6
 
 
7
// #cgo LDFLAGS: -lsasl2
 
8
//
 
9
// struct sasl_conn {};
 
10
//
 
11
// #include <stdlib.h>
 
12
// #include <sasl/sasl.h>
 
13
//
 
14
// sasl_callback_t *mgo_sasl_callbacks(const char *username, const char *password);
 
15
//
 
16
import "C"
 
17
 
 
18
import (
 
19
        "fmt"
 
20
        "strings"
 
21
        "sync"
 
22
        "unsafe"
 
23
)
 
24
 
 
25
type saslStepper interface {
 
26
        Step(serverData []byte) (clientData []byte, done bool, err error)
 
27
        Close()
 
28
}
 
29
 
 
30
type saslSession struct {
 
31
        conn *C.sasl_conn_t
 
32
        step int
 
33
        mech string
 
34
 
 
35
        cstrings  []*C.char
 
36
        callbacks *C.sasl_callback_t
 
37
}
 
38
 
 
39
var initError error
 
40
var initOnce sync.Once
 
41
 
 
42
func initSASL() {
 
43
        rc := C.sasl_client_init(nil)
 
44
        if rc != C.SASL_OK {
 
45
                initError = saslError(rc, nil, "cannot initialize SASL library")
 
46
        }
 
47
}
 
48
 
 
49
func New(username, password, mechanism, service, host string) (saslStepper, error) {
 
50
        initOnce.Do(initSASL)
 
51
        if initError != nil {
 
52
                return nil, initError
 
53
        }
 
54
 
 
55
        ss := &saslSession{mech: mechanism}
 
56
        if service == "" {
 
57
                service = "mongodb"
 
58
        }
 
59
        if i := strings.Index(host, ":"); i >= 0 {
 
60
                host = host[:i]
 
61
        }
 
62
        ss.callbacks = C.mgo_sasl_callbacks(ss.cstr(username), ss.cstr(password))
 
63
        rc := C.sasl_client_new(ss.cstr(service), ss.cstr(host), nil, nil, ss.callbacks, 0, &ss.conn)
 
64
        if rc != C.SASL_OK {
 
65
                ss.Close()
 
66
                return nil, saslError(rc, nil, "cannot create new SASL client")
 
67
        }
 
68
        return ss, nil
 
69
}
 
70
 
 
71
func (ss *saslSession) cstr(s string) *C.char {
 
72
        cstr := C.CString(s)
 
73
        ss.cstrings = append(ss.cstrings, cstr)
 
74
        return cstr
 
75
}
 
76
 
 
77
func (ss *saslSession) Close() {
 
78
        for _, cstr := range ss.cstrings {
 
79
                C.free(unsafe.Pointer(cstr))
 
80
        }
 
81
        ss.cstrings = nil
 
82
 
 
83
        if ss.callbacks != nil {
 
84
                C.free(unsafe.Pointer(ss.callbacks))
 
85
        }
 
86
 
 
87
        // The documentation of SASL dispose makes it clear that this should only
 
88
        // be done when the connection is done, not when the authentication phase
 
89
        // is done, because an encryption layer may have been negotiated.
 
90
        // Even then, we'll do this for now, because it's simpler and prevents
 
91
        // keeping track of this state for every socket. If it breaks, we'll fix it.
 
92
        C.sasl_dispose(&ss.conn)
 
93
}
 
94
 
 
95
func (ss *saslSession) Step(serverData []byte) (clientData []byte, done bool, err error) {
 
96
        ss.step++
 
97
        if ss.step > 10 {
 
98
                return nil, false, fmt.Errorf("too many SASL steps without authentication")
 
99
        }
 
100
        var cclientData *C.char
 
101
        var cclientDataLen C.uint
 
102
        var rc C.int
 
103
        if ss.step == 1 {
 
104
                var mechanism *C.char // ignored - must match cred
 
105
                rc = C.sasl_client_start(ss.conn, ss.cstr(ss.mech), nil, &cclientData, &cclientDataLen, &mechanism)
 
106
        } else {
 
107
                var cserverData *C.char
 
108
                var cserverDataLen C.uint
 
109
                if len(serverData) > 0 {
 
110
                        cserverData = (*C.char)(unsafe.Pointer(&serverData[0]))
 
111
                        cserverDataLen = C.uint(len(serverData))
 
112
                }
 
113
                rc = C.sasl_client_step(ss.conn, cserverData, cserverDataLen, nil, &cclientData, &cclientDataLen)
 
114
        }
 
115
        if cclientData != nil && cclientDataLen > 0 {
 
116
                clientData = C.GoBytes(unsafe.Pointer(cclientData), C.int(cclientDataLen))
 
117
        }
 
118
        if rc == C.SASL_OK {
 
119
                return clientData, true, nil
 
120
        }
 
121
        if rc == C.SASL_CONTINUE {
 
122
                return clientData, false, nil
 
123
        }
 
124
        return nil, false, saslError(rc, ss.conn, "cannot establish SASL session")
 
125
}
 
126
 
 
127
func saslError(rc C.int, conn *C.sasl_conn_t, msg string) error {
 
128
        var detail string
 
129
        if conn == nil {
 
130
                detail = C.GoString(C.sasl_errstring(rc, nil, nil))
 
131
        } else {
 
132
                detail = C.GoString(C.sasl_errdetail(conn))
 
133
        }
 
134
        return fmt.Errorf(msg + ": " + detail)
 
135
}