forked from Shiloh/githaven
253 lines
6.4 KiB
Go
Vendored
253 lines
6.4 KiB
Go
Vendored
// @author Couchbase <info@couchbase.com>
|
|
// @copyright 2018 Couchbase, Inc.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
// Package scramsha provides implementation of client side SCRAM-SHA
|
|
// via HTTP according to https://tools.ietf.org/html/rfc7804
|
|
package scramsha
|
|
|
|
import (
|
|
"encoding/base64"
|
|
"github.com/pkg/errors"
|
|
"io"
|
|
"io/ioutil"
|
|
"net/http"
|
|
"strings"
|
|
)
|
|
|
|
// Header names and attribute prefixes used to build SCRAM-SHA requests
// and to parse the SCRAM-SHA responses returned by the target.
const (
	// WWWAuthenticate is the response header read after the start request.
	WWWAuthenticate = "WWW-Authenticate"
	// AuthenticationInfo is the response header read after the final request.
	AuthenticationInfo = "Authentication-Info"
	// Authorization is the request header that carries the client messages.
	Authorization = "Authorization"
	// DataPrefix introduces the base64 encoded SCRAM message attribute.
	DataPrefix = "data="
	// SidPrefix introduces the session id attribute.
	SidPrefix = "sid="
)
|
|
|
|
// Request is an HTTP request whose body can be rewound and sent again.
type Request struct {
	// body keeps the seekable reader handed to NewRequest so that the
	// payload can be repositioned to its start before a resend.
	body io.ReadSeeker

	// The embedded *http.Request exposes every field and method of the
	// underlying request directly on *Request.
	*http.Request
}
|
|
|
|
// lenReader is satisfied by readers that can report a length through
// Len; NewRequest uses that value as the request Content-Length.
type lenReader interface {
	Len() int
}
|
|
|
|
// NewRequest creates http request that can be retried
|
|
func NewRequest(method, url string, body io.ReadSeeker) (*Request, error) {
|
|
// Wrap the body in a noop ReadCloser if non-nil. This prevents the
|
|
// reader from being closed by the HTTP client.
|
|
var rcBody io.ReadCloser
|
|
if body != nil {
|
|
rcBody = ioutil.NopCloser(body)
|
|
}
|
|
|
|
// Make the request with the noop-closer for the body.
|
|
httpReq, err := http.NewRequest(method, url, rcBody)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Check if we can set the Content-Length automatically.
|
|
if lr, ok := body.(lenReader); ok {
|
|
httpReq.ContentLength = int64(lr.Len())
|
|
}
|
|
|
|
return &Request{body, httpReq}, nil
|
|
}
|
|
|
|
// encode returns the standard (padded) base64 encoding of str.
func encode(str string) string {
	enc := base64.StdEncoding
	out := make([]byte, enc.EncodedLen(len(str)))
	enc.Encode(out, []byte(str))
	return string(out)
}
|
|
|
|
func decode(str string) (string, error) {
|
|
bytes, err := base64.StdEncoding.DecodeString(str)
|
|
if err != nil {
|
|
return "", errors.Errorf("Cannot base64 decode %s",
|
|
str)
|
|
}
|
|
return string(bytes), err
|
|
}
|
|
|
|
func trimPrefix(s, prefix string) (string, error) {
|
|
l := len(s)
|
|
trimmed := strings.TrimPrefix(s, prefix)
|
|
if l == len(trimmed) {
|
|
return trimmed, errors.Errorf("Prefix %s not found in %s",
|
|
prefix, s)
|
|
}
|
|
return trimmed, nil
|
|
}
|
|
|
|
// drainBody reads whatever is left of the response body, throws it
// away and closes the body. Read errors are deliberately ignored: the
// response is being discarded anyway.
func drainBody(resp *http.Response) {
	body := resp.Body
	defer body.Close()
	_, _ = io.Copy(ioutil.Discard, body)
}
|
|
|
|
// DoScramSha performs SCRAM-SHA handshake via Http
|
|
func DoScramSha(req *Request,
|
|
username string,
|
|
password string,
|
|
client *http.Client) (*http.Response, error) {
|
|
|
|
method := "SCRAM-SHA-512"
|
|
s, err := NewScramSha("SCRAM-SHA512")
|
|
if err != nil {
|
|
return nil, errors.Wrap(err,
|
|
"Unable to initialize SCRAM-SHA handler")
|
|
}
|
|
|
|
message, err := s.GetStartRequest(username)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
encodedMessage := method + " " + DataPrefix + encode(message)
|
|
|
|
req.Header.Set(Authorization, encodedMessage)
|
|
|
|
res, err := client.Do(req.Request)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "Problem sending SCRAM-SHA start"+
|
|
"request")
|
|
}
|
|
|
|
if res.StatusCode != http.StatusUnauthorized {
|
|
return res, nil
|
|
}
|
|
|
|
authHeader := res.Header.Get(WWWAuthenticate)
|
|
if authHeader == "" {
|
|
drainBody(res)
|
|
return nil, errors.Errorf("Header %s is not populated in "+
|
|
"SCRAM-SHA start response", WWWAuthenticate)
|
|
}
|
|
|
|
authHeader, err = trimPrefix(authHeader, method+" ")
|
|
if err != nil {
|
|
if strings.HasPrefix(authHeader, "Basic ") {
|
|
// user not found
|
|
return res, nil
|
|
}
|
|
drainBody(res)
|
|
return nil, errors.Wrapf(err, "Error while parsing SCRAM-SHA "+
|
|
"start response %s", authHeader)
|
|
}
|
|
|
|
drainBody(res)
|
|
|
|
sid, response, err := parseSidAndData(authHeader)
|
|
if err != nil {
|
|
return nil, errors.Wrapf(err, "Error while parsing SCRAM-SHA "+
|
|
"start response %s", authHeader)
|
|
}
|
|
|
|
err = s.HandleStartResponse(response)
|
|
if err != nil {
|
|
return nil, errors.Wrapf(err, "Error parsing SCRAM-SHA start "+
|
|
"response %s", response)
|
|
}
|
|
|
|
message = s.GetFinalRequest(password)
|
|
encodedMessage = method + " " + SidPrefix + sid + "," + DataPrefix +
|
|
encode(message)
|
|
|
|
req.Header.Set(Authorization, encodedMessage)
|
|
|
|
// rewind request body so it can be resent again
|
|
if req.body != nil {
|
|
if _, err = req.body.Seek(0, 0); err != nil {
|
|
return nil, errors.Errorf("Failed to seek body: %v",
|
|
err)
|
|
}
|
|
}
|
|
|
|
res, err = client.Do(req.Request)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "Problem sending SCRAM-SHA final"+
|
|
"request")
|
|
}
|
|
|
|
if res.StatusCode == http.StatusUnauthorized {
|
|
// TODO retrieve and return error
|
|
return res, nil
|
|
}
|
|
|
|
if res.StatusCode >= http.StatusInternalServerError {
|
|
// in this case we cannot expect server to set headers properly
|
|
return res, nil
|
|
}
|
|
|
|
authHeader = res.Header.Get(AuthenticationInfo)
|
|
if authHeader == "" {
|
|
drainBody(res)
|
|
return nil, errors.Errorf("Header %s is not populated in "+
|
|
"SCRAM-SHA final response", AuthenticationInfo)
|
|
}
|
|
|
|
finalSid, response, err := parseSidAndData(authHeader)
|
|
if err != nil {
|
|
drainBody(res)
|
|
return nil, errors.Wrapf(err, "Error while parsing SCRAM-SHA "+
|
|
"final response %s", authHeader)
|
|
}
|
|
|
|
if finalSid != sid {
|
|
drainBody(res)
|
|
return nil, errors.Errorf("Sid %s returned by server "+
|
|
"doesn't match the original sid %s", finalSid, sid)
|
|
}
|
|
|
|
err = s.HandleFinalResponse(response)
|
|
if err != nil {
|
|
drainBody(res)
|
|
return nil, errors.Wrapf(err,
|
|
"Error handling SCRAM-SHA final server response %s",
|
|
response)
|
|
}
|
|
return res, nil
|
|
}
|
|
|
|
func parseSidAndData(authHeader string) (string, string, error) {
|
|
sidIndex := strings.Index(authHeader, SidPrefix)
|
|
if sidIndex < 0 {
|
|
return "", "", errors.Errorf("Cannot find %s in %s",
|
|
SidPrefix, authHeader)
|
|
}
|
|
|
|
sidEndIndex := strings.Index(authHeader, ",")
|
|
if sidEndIndex < 0 {
|
|
return "", "", errors.Errorf("Cannot find ',' in %s",
|
|
authHeader)
|
|
}
|
|
|
|
sid := authHeader[sidIndex+len(SidPrefix) : sidEndIndex]
|
|
|
|
dataIndex := strings.Index(authHeader, DataPrefix)
|
|
if dataIndex < 0 {
|
|
return "", "", errors.Errorf("Cannot find %s in %s",
|
|
DataPrefix, authHeader)
|
|
}
|
|
|
|
data, err := decode(authHeader[dataIndex+len(DataPrefix):])
|
|
if err != nil {
|
|
return "", "", err
|
|
}
|
|
return sid, data, nil
|
|
}
|