package big

import (
	"encoding/json"
	"fmt"
	"math/big"
)

// BigIntMaxSerializedLen is the max length of a byte slice representing a
// CBOR serialized big.
const BigIntMaxSerializedLen = 128

// Int embeds a *big.Int, so every method of *big.Int is reachable on an Int
// in addition to the helpers declared in this package. The zero value of Int
// holds a nil pointer.
type Int struct {
	*big.Int
}

// NewInt returns an Int holding the signed value i.
func NewInt(i int64) Int {
	return Int{Int: big.NewInt(i)}
}

// NewIntUnsigned returns an Int holding the unsigned value i.
func NewIntUnsigned(i uint64) Int {
	return Int{Int: new(big.Int).SetUint64(i)}
}

// NewFromGo returns an Int holding a copy of i; the result does not share
// its pointer with i.
func NewFromGo(i *big.Int) Int {
	return Int{Int: new(big.Int).Set(i)}
}

// Zero returns a freshly allocated Int with value 0.
func Zero() Int {
	return NewInt(0)
}

// PositiveFromUnsignedBytes interprets b as the bytes of a big-endian unsigned
// integer and returns a positive Int with this absolute value.
func PositiveFromUnsignedBytes(b []byte) Int {
	return Int{Int: new(big.Int).SetBytes(b)}
}

// MustFromString converts a decimal string into an Int and panics if the
// conversion is not successful.
func MustFromString(s string) Int {
	parsed, err := FromString(s)
	if err == nil {
		return parsed
	}
	panic(err)
}

// FromString parses s as a base-10 integer. On failure it returns the zero
// Int (nil pointer) and a non-nil error.
func FromString(s string) (Int, error) {
	if parsed, ok := new(big.Int).SetString(s, 10); ok {
		return Int{Int: parsed}, nil
	}
	return Int{}, fmt.Errorf("failed to parse string as a big int")
}

// Copy returns a new Int with the same value as bi that does not share its
// pointer with bi.
func (bi Int) Copy() Int {
	dup := new(big.Int)
	dup.Set(bi.Int)
	return Int{Int: dup}
}

// Product returns the product of all ints, or 1 when called with no
// arguments.
func Product(ints ...Int) Int {
	acc := NewInt(1)
	for idx := range ints {
		acc = Mul(acc, ints[idx])
	}
	return acc
}

// Mul returns a * b as a new Int.
func Mul(a, b Int) Int {
	out := new(big.Int)
	return Int{Int: out.Mul(a.Int, b.Int)}
}

// Div returns the quotient a / b as computed by big.Int.Div (Euclidean
// division), as a new Int.
func Div(a, b Int) Int {
	out := new(big.Int)
	return Int{Int: out.Div(a.Int, b.Int)}
}

// Mod returns the modulus a % b as computed by big.Int.Mod (Euclidean
// modulus), as a new Int.
func Mod(a, b Int) Int {
	out := new(big.Int)
	return Int{Int: out.Mod(a.Int, b.Int)}
}

// Add returns a + b as a new Int.
func Add(a, b Int) Int {
	out := new(big.Int)
	return Int{Int: out.Add(a.Int, b.Int)}
}

// Sum returns the sum of all ints, or 0 when called with no arguments.
func Sum(ints ...Int) Int {
	total := Zero()
	for idx := range ints {
		total = Add(total, ints[idx])
	}
	return total
}

// Subtract returns num1 minus every value in ints. When ints is empty, num1
// itself is returned rather than a copy.
func Subtract(num1 Int, ints ...Int) Int {
	remaining := num1
	for idx := range ints {
		remaining = Sub(remaining, ints[idx])
	}
	return remaining
}

// Sub returns a - b as a new Int.
func Sub(a, b Int) Int {
	out := new(big.Int)
	return Int{Int: out.Sub(a.Int, b.Int)}
}

// Exp returns a**e unless e <= 0 (in which case it returns 1).
func Exp(a Int, e Int) Int {
	return Int{Int: new(big.Int).Exp(a.Int, e.Int, nil)}
}

// Lsh returns a << n as a new Int.
func Lsh(a Int, n uint) Int {
	return Int{Int: new(big.Int).Lsh(a.Int, n)}
}

// Rsh returns a >> n as a new Int.
func Rsh(a Int, n uint) Int {
	return Int{Int: new(big.Int).Rsh(a.Int, n)}
}

// BitLen returns the value of a.Int.BitLen() converted to uint.
func BitLen(a Int) uint {
	return uint(a.Int.BitLen())
}

// Max returns the larger of x and y. The returned Int is one of the
// arguments, not a copy: when the two compare equal, y is returned, except
// that x is returned when both are zero.
func Max(x, y Int) Int {
	// The original carried a signed-zero check copied from math.Max; it can
	// never fire for integers, so only the "both zero -> x" outcome is kept.
	if x.Sign() == 0 && y.Sign() == 0 {
		return x
	}
	if Cmp(x, y) > 0 {
		return x
	}
	return y
}

// Min returns the smaller of x and y. The returned Int is one of the
// arguments, not a copy: when the two compare equal, y is returned.
func Min(x, y Int) Int {
	// The signed-zero special case copied from math.Min was unreachable for
	// integers and always ended up returning y, as the fall-through does.
	if Cmp(x, y) < 0 {
		return x
	}
	return y
}

// Cmp compares a and b and returns a negative number when a < b, zero when
// a == b, and a positive number when a > b.
func Cmp(a, b Int) int {
	return a.Int.Cmp(b.Int)
}

// LessThan returns true if bi < o.
func (bi Int) LessThan(o Int) bool {
	return Cmp(bi, o) < 0
}

// LessThanEqual returns true if bi <= o.
func (bi Int) LessThanEqual(o Int) bool {
	return Cmp(bi, o) <= 0
}

// GreaterThan returns true if bi > o.
func (bi Int) GreaterThan(o Int) bool {
	return Cmp(bi, o) > 0
}

// GreaterThanEqual returns true if bi >= o.
func (bi Int) GreaterThanEqual(o Int) bool {
	return Cmp(bi, o) >= 0
}

// Neg returns the negative of bi as a new Int.
func (bi Int) Neg() Int {
	return Int{Int: new(big.Int).Neg(bi.Int)}
}

// Abs returns the absolute value of bi.
func (bi Int) Abs() Int { if bi.GreaterThanEqual(Zero()) { return bi.Copy() } return bi.Neg() } // Equals returns true if bi == o func (bi Int) Equals(o Int) bool { return Cmp(bi, o) == 0 } func (bi *Int) MarshalJSON() ([]byte, error) { if bi.Int == nil { zero := Zero() return json.Marshal(zero) } return json.Marshal(bi.String()) } func (bi *Int) UnmarshalJSON(b []byte) error { var s string if err := json.Unmarshal(b, &s); err != nil { return err } i, ok := big.NewInt(0).SetString(s, 10) if !ok { return fmt.Errorf("failed to parse big string: '%s'", string(b)) } bi.Int = i return nil } func (bi *Int) Bytes() ([]byte, error) { if bi.Int == nil { return []byte{}, fmt.Errorf("failed to convert to bytes, big is nil") } switch { case bi.Sign() > 0: return append([]byte{0}, bi.Int.Bytes()...), nil case bi.Sign() < 0: return append([]byte{1}, bi.Int.Bytes()...), nil default: // bi.Sign() == 0: return []byte{}, nil } } func FromBytes(buf []byte) (Int, error) { if len(buf) == 0 { return NewInt(0), nil } var negative bool switch buf[0] { case 0: negative = false case 1: negative = true default: return Zero(), fmt.Errorf("big int prefix should be either 0 or 1, got %d", buf[0]) } i := big.NewInt(0).SetBytes(buf[1:]) if negative { i.Neg(i) } return Int{i}, nil } func (bi *Int) MarshalBinary() ([]byte, error) { if bi.Int == nil { zero := Zero() return zero.Bytes() } return bi.Bytes() } func (bi *Int) UnmarshalBinary(buf []byte) error { i, err := FromBytes(buf) if err != nil { return err } *bi = i return nil } func (bi *Int) IsZero() bool { return bi.Int.Sign() == 0 } func (bi *Int) Nil() bool { return bi.Int == nil } func (bi *Int) NilOrZero() bool { return bi.Int == nil || bi.Int.Sign() == 0 }