package pieces

import (
	"math/bits"

	"golang.org/x/xerrors"

	"github.com/minio/sha256-simd"

	"fil_integrate/build/fr32"
	"fil_integrate/build/state-types/abi"
)

// NODE_SIZE is 32, matching the 32-byte node width that the Merkle routines
// in this file operate on.
// NOTE(review): those routines use the literal 32 rather than this constant,
// and the constant is not referenced in the visible part of this file.
const NODE_SIZE = 32

// NOTE(review): a is not referenced in the visible part of this file; its
// purpose is unclear — confirm whether it is still needed.
var a = false

// nextPowerOfTwo returns the smallest power of two that is greater than or
// equal to v. A value that is already a power of two is returned unchanged.
//
// Edge cases: 0 yields 0, and any v above 1<<63 also yields 0 because the
// result does not fit in a uint64.
func nextPowerOfTwo(v uint64) uint64 {
	// Smear the highest set bit of v-1 into every lower position, then add
	// one to carry into the next power of two.
	v--
	v |= v >> 1
	v |= v >> 2
	v |= v >> 4
	v |= v >> 8
	v |= v >> 16
	// Required for 64-bit operands: without this step, bits more than 31
	// positions below the highest set bit are never filled.
	v |= v >> 32
	v++
	return v
}

// GeneratePieceCommitmentFast computes a 32-byte piece commitment for data.
//
// unpad is the unpadded payload size. It is rounded up by paddedSize to the
// unpadded size of a power-of-two piece; data is copied into a zero-filled
// buffer of that size (bytes of data beyond that size are not copied), the
// buffer is fr32-padded, and the Merkle root of the padded bytes is returned.
//
// An error is returned only when the Merkle tree cannot be built over the
// padded buffer.
func GeneratePieceCommitmentFast(data []byte, unpad uint64) ([32]byte, error) {
	inLen := paddedSize(unpad)

	in := make([]byte, inLen)
	copy(in, data)

	// The padded buffer must hold 128 bytes for every 127 bytes of input.
	// Size it from inLen rather than from unpad: rounding unpad itself up to
	// a power of two yields a buffer that is too small whenever unpad is not
	// already within the 127/128 payload of that power of two (for example
	// unpad == 128 needs 256 padded bytes, and unpad == 0 needs 128).
	// NOTE(review): with an undersized buffer fr32.Pad presumably either
	// panics or silently drops trailing input — confirm against fr32.
	out := make([]byte, uint64(inLen)/127*128)

	fr32.Pad(in, out)

	// MerkleTreeRecurse is the recursive counterpart of this call.
	root, err := MerkleTreeLoop(out)
	if err != nil {
		return [32]byte{}, err
	}

	var result [32]byte
	copy(result[:], root)
	return result, nil
}

// MerkleTreeRecurse computes the root of a binary SHA-256 Merkle tree over D
// by recursively halving the input.
//
// A 32-byte input is treated as a leaf and returned as is. Longer inputs are
// split in the middle; each half's root has the top two bits of its last byte
// cleared before the two are hashed together, and the digest is trimmed the
// same way. Leaves are sub-slices of D, so trimming them modifies D in place.
//
// An error (with a nil slice) is returned when the input, or any half reached
// by repeated halving, is shorter than 32 bytes.
func MerkleTreeRecurse(D []byte) ([]byte, error) {
	n := len(D)
	// Too short to hold even a single node.
	if n < 32 {
		return nil, xerrors.Errorf("can not generate the merkle tree")
	}
	// Leaf node: return it directly.
	if n == 32 {
		return D, nil
	}
	mid := n / 2

	// Root of the left subtree.
	left, err := MerkleTreeRecurse(D[:mid])
	if err != nil {
		return nil, err
	}
	// Trim into the Fr field.
	trim_to_fr32(left)

	// Root of the right subtree.
	right, err := MerkleTreeRecurse(D[mid:])
	if err != nil {
		return nil, err
	}
	trim_to_fr32(right)

	// Hash the two children to obtain this node.
	h := sha256.New()
	h.Write(left)
	h.Write(right)
	res := h.Sum(nil)
	trim_to_fr32(res)
	return res, nil
}

func MerkleTreeLoop(D []byte) ([]byte, error) {
	n := uint64(len(D))

	if n != nextPowerOfTwo(n) {
		return nil, xerrors.Errorf("can not generate the merkle tree")
	}
	for lenth := uint64(32); lenth < n; lenth <<= 1 {
		for index := uint64(0); index < n; {
			windex := index
			h := sha256.New()

			// write left child
			trim_to_fr32(D[index : index+32])
			h.Write(D[index : index+32])
			index += lenth
			// write right child
			trim_to_fr32(D[index : index+32])
			h.Write(D[index : index+32])
			index += lenth

			res := h.Sum(nil)
			copy(D[windex:windex+32], res)
		}
	}
	trim_to_fr32(D[:32])
	return D[:32], nil
}

// trim_to_fr32 clears the two most significant bits of data[31] in place so
// that the 32-byte value fits in Fr. data must be at least 32 bytes long;
// bytes other than index 31 are left untouched.
func trim_to_fr32(data []byte) {
	const topTwoBits = 0b1100_0000
	data[31] &^= topTwoBits
}

// paddedSize maps a payload size in bytes to the unpadded size of the
// smallest power-of-two padded piece that can hold it.
//
// Sizes up to 127 map to 127. Larger sizes are counted in 127-byte chunks
// (rounded up), each chunk occupying 128 padded bytes; the padded total is
// then raised to the next power of two when it is not one already, and
// converted back through abi.PaddedPieceSize.Unpadded.
func paddedSize(size uint64) abi.UnpaddedPieceSize {
	if size <= 127 {
		return abi.UnpaddedPieceSize(127)
	}

	// Chunks of 127 payload bytes needed to cover size, rounded up.
	chunks := (size + 126) / 127
	// Each chunk takes 128 bytes once fr32-padded.
	padded := chunks * 128

	// More than one bit set means padded is not a power of two: move up to
	// the power of two just above its highest set bit.
	if bits.OnesCount64(padded) != 1 {
		padded = 1 << uint(bits.Len64(padded))
	}

	// Unpadded size of the piece that was just determined.
	return abi.PaddedPieceSize(padded).Unpadded()
}
