package customizations

import (
	"context"
	"crypto/sha256"
	"encoding/hex"
	"fmt"
	"io"

	"github.com/aws/smithy-go"
	"github.com/aws/smithy-go/middleware"
	smithyhttp "github.com/aws/smithy-go/transport/http"
)

// AddTreeHashMiddleware registers the TreeHash middleware on the Finalize
// step of the given stack, using the middleware.Before relative position.
// The middleware fills in Glacier's required checksum headers.
//
// Any error reported by the stack while adding the middleware is returned.
func AddTreeHashMiddleware(stack *middleware.Stack) error {
	treeHash := &TreeHash{}
	return stack.Finalize.Add(treeHash, middleware.Before)
}

// TreeHash provides the finalize middleware that will automatically
// set the SHA-256 content hash and SHA-256 tree hash headers on a request
// if the tree hash header has not already been set.
//
// The type carries no state of its own.
type TreeHash struct{}

// ID returns the identifier under which this middleware is registered
// in a middleware stack.
func (*TreeHash) ID() string {
	const middlewareID = "Glacier:TreeHash"
	return middlewareID
}

// HandleFinalize implements the finalize middleware handler method.
//
// It expects input.Request to be a *smithyhttp.Request, adds the checksum
// headers to it via addChecksum, and then hands the input to the next
// handler. If the request has a different type, or the checksums cannot be
// added, a *smithy.SerializationError is returned and the next handler is
// not invoked.
func (*TreeHash) HandleFinalize(
	ctx context.Context, input middleware.FinalizeInput, next middleware.FinalizeHandler,
) (
	output middleware.FinalizeOutput, metadata middleware.Metadata, err error,
) {
	httpReq, isHTTP := input.Request.(*smithyhttp.Request)
	if !isHTTP {
		err = &smithy.SerializationError{
			Err: fmt.Errorf("unknown request type %T", input.Request),
		}
		return output, metadata, err
	}

	if checksumErr := addChecksum(httpReq); checksumErr != nil {
		err = &smithy.SerializationError{Err: checksumErr}
		return output, metadata, err
	}

	return next.HandleFinalize(ctx, input)
}

// addChecksum sets the X-Amz-Sha256-Tree-Hash and X-Amz-Content-Sha256
// headers on req from the contents of its body stream.
//
// It is a no-op when the request has no body stream, or when the tree hash
// header already carries a value. Otherwise the stream must be seekable: the
// body is read in full to compute the hashes and then rewound.
//
// An error is returned if the stream is not seekable or if rewinding it
// fails. Note that both headers are written together; an existing
// X-Amz-Content-Sha256 value is overwritten when the tree hash is computed.
func addChecksum(req *smithyhttp.Request) error {
	const (
		treeHashHeader      = "X-Amz-Sha256-Tree-Hash"
		contentSHA256Header = "X-Amz-Content-Sha256"
	)

	// Nothing to hash without a body, and a tree hash that is already
	// present is left untouched.
	if req.GetStream() == nil || req.Header.Get(treeHashHeader) != "" {
		return nil
	}

	// Hashing consumes the stream, so it must be possible to rewind it.
	if !req.IsStreamSeekable() {
		return fmt.Errorf("glacier content-sha256 and tree hash can only be automatically computed if the request body is seekable")
	}

	h := computeHashes(req.GetStream())
	if err := req.RewindStream(); err != nil {
		return err
	}

	// Both digests are sent hex encoded.
	req.Header.Set(treeHashHeader, hex.EncodeToString(h.TreeHash))
	req.Header.Set(contentSHA256Header, hex.EncodeToString(h.LinearHash))

	return nil
}

// Hash contains information about the tree-hash and linear hash of a
// Glacier payload. This structure is generated by computeHashes().
//
// Both fields hold raw digest bytes; they are hex encoded separately when
// written to request headers.
type Hash struct {
	// TreeHash is the root of the tree hash built from the SHA-256 hashes
	// of the payload's 1MB chunks. It is nil when the payload is empty.
	TreeHash   []byte
	// LinearHash is the SHA-256 hash of the payload bytes read, in order.
	LinearHash []byte
}

// computeHashes reads r to the end and returns both the tree hash and the
// linear SHA-256 hash of the bytes it read.
//
// The data is consumed in 1MB chunks; each chunk contributes one leaf to the
// tree hash and is also fed into the running linear hash. Reading stops at
// the first read that yields no bytes or reports an error. The error itself
// (io.EOF or otherwise) is not returned to the caller.
//
// Note that this does not perform seeks before or after, these must be done manually.
//
// See http://docs.aws.amazon.com/amazonglacier/latest/dev/checksum-calculations.html for more information.
func computeHashes(r io.Reader) Hash {
	const chunkSize = 1024 * 1024

	var leaves [][]byte
	linear := sha256.New()
	chunk := make([]byte, chunkSize)

	for lastChunk := false; !lastChunk; {
		// Fill a whole 1MB chunk unless the reader runs out first.
		n, err := io.ReadFull(r, chunk)
		if n == 0 {
			break
		}

		data := chunk[:n]
		leaf := sha256.Sum256(data)
		leaves = append(leaves, leaf[:])
		linear.Write(data) // the linear hash covers the same bytes, in order

		// A short or failed read means there is nothing more to consume.
		lastChunk = err != nil
	}

	return Hash{
		TreeHash:   computeTreeHash(leaves),
		LinearHash: linear.Sum(nil),
	}
}

// computeTreeHash builds a tree hash root node given a slice of
// hashes. The Glacier tree hash is derived from the SHA256 hashes of 1MB
// chunks of the data, which are expected to be passed in as the leaves.
//
// Each pass replaces every adjacent pair of nodes with the SHA256 of their
// concatenation (left followed by right). When a level has an odd number of
// nodes, the trailing node is carried up to the next level unchanged. Passes
// repeat until a single node remains.
//
// It returns nil when hashes is nil or empty, and returns the sole element
// itself when hashes has exactly one entry. The input slice is not modified.
//
// See http://docs.aws.amazon.com/amazonglacier/latest/dev/checksum-calculations.html for more information.
func computeTreeHash(hashes [][]byte) []byte {
	// len of a nil slice is 0, so one check covers both nil and empty.
	if len(hashes) == 0 {
		return nil
	}

	// A single hasher is reused for every pair instead of building a
	// concatenated buffer per pair.
	pair := sha256.New()

	for len(hashes) > 1 {
		// Each level holds half as many nodes, rounded up.
		next := make([][]byte, 0, (len(hashes)+1)/2)

		for i := 0; i < len(hashes); i += 2 {
			if i+1 == len(hashes) {
				// Odd node out: promote it as-is.
				next = append(next, hashes[i])
				break
			}

			pair.Reset()
			pair.Write(hashes[i])
			pair.Write(hashes[i+1])
			next = append(next, pair.Sum(nil))
		}

		hashes = next
	}

	return hashes[0]
}
