use core::{cmp::Ordering, convert::TryInto, ops::{Div, Rem}};
use ref_cast::RefCast;
use crate::{Digit, DoubleDigit, Odd, Result, Unsigned, Wrapping};
use crate::numbers::{Array, Bits, Number, NumberMut};
/// Computes the multiplicative inverse of the odd number `x` in wrapping
/// arithmetic on `Unsigned<D, E>`.
///
/// Starting from the guess `1`, the Newton/Hensel step `y <- y * (2 - x * y)`
/// is applied repeatedly using `wrapping_mul` / `wrapping_sub`. Each step
/// doubles the number of low bits of `y` for which `x * y == 1` holds, so the
/// `(D + E) * log2(Digit::BITS)` rounds performed here are more than enough
/// (assuming `Unsigned<D, E>` holds `D + E` digits — TODO confirm).
///
/// The number of rounds depends only on the const parameters, not on the
/// value of `x`.
pub fn wrapping_invert_odd<const D: usize, const E: usize>(x: &Odd<D, E>) -> Odd<D, E> {
    // `Digit::BITS` is a power of two, so `trailing_zeros` is its base-2 logarithm.
    let rounds = (D + E) * (Digit::BITS.trailing_zeros() as usize);

    let value: &Unsigned<D, E> = &*x;
    let two: Unsigned<D, E> = Unsigned::from(2);
    let mut inverse: Unsigned<D, E> = Unsigned::from(1);

    for _ in 0..rounds {
        let correction = two.wrapping_sub(&value.wrapping_mul(&inverse));
        inverse = inverse.wrapping_mul(&correction);
    }

    Odd(inverse)
}
/// Computes the wrapping multiplicative inverse of `unsigned`.
///
/// # Errors
///
/// Propagates the error of the `&Unsigned -> &Odd` conversion; no inverse is
/// computed in that case.
pub fn wrapping_invert<const D: usize, const E: usize>(unsigned: &Unsigned<D, E>) -> Result<Unsigned<D, E>> {
    let odd: &Odd<D, E> = unsigned.try_into()?;
    let inverse = wrapping_invert_odd(odd);
    Ok(inverse.into())
}
/// Divides the double-width value `hi * 2^Digit::BITS + lo` by `divisor`.
///
/// Returns `(quotient, remainder)`, both truncated to a single `Digit`.
/// The quotient only fits into one digit when `hi < divisor`; this is
/// checked by a `debug_assert!` only, so release builds silently truncate.
///
/// # Panics
///
/// Panics on division by zero when `divisor == 0`.
#[inline]
pub fn div_digits(hi: Digit, lo: Digit, divisor: Digit) -> (Digit, Digit) {
    let dividend = ((hi as DoubleDigit) << Digit::BITS) + lo as DoubleDigit;
    let wide_divisor = divisor as DoubleDigit;

    let q = dividend / wide_divisor;
    debug_assert!(q <= Digit::MAX as _);

    let r = dividend % wide_divisor;
    debug_assert!(r <= Digit::MAX as _);

    (q as Digit, r as Digit)
}
/// Divides `number` in place by the single digit `modulus` and returns the
/// remainder.
///
/// Digits are processed from the most significant one downwards; the
/// remainder of each step becomes the high half of the next two-digit
/// dividend passed to `div_digits`.
///
/// Without the `ct-maybe` feature only the significant digits are visited;
/// with it, all `N::DIGITS` digits are visited (presumably so the amount of
/// work does not depend on the value — confirm).
pub fn div_rem_assign_digit<N: NumberMut>(number: &mut N, modulus: Digit) -> Digit {
    #[cfg(not(feature = "ct-maybe"))]
    let active = number.significant_digits().len();
    #[cfg(feature = "ct-maybe")]
    let active = N::DIGITS;

    number[..active].iter_mut().rev().fold(0, |carry, digit| {
        let (quotient, remainder) = div_digits(carry, *digit, modulus);
        *digit = quotient;
        remainder
    })
}
/// Divides `x` by `n` and returns `(quotient, remainder)`.
///
/// The quotient has the dividend's type `T`; the remainder is returned as an
/// `Unsigned<D, E>`, i.e. in the divisor's type.
///
/// Shortcuts taken before the general algorithm:
/// - `x == 0` yields `(0, 0)`;
/// - a divisor for which `is_digit()` holds is handled by
///   `div_rem_assign_digit` (a divisor of exactly `1` returns a copy of `x`
///   and remainder `0`);
/// - `n > x` yields `(0, x)` and `n == x` yields `(1, 0)`.
///
/// Otherwise both operands are shifted left until the divisor's leading digit
/// has no leading zero bits, and schoolbook long division is performed: for
/// each quotient position a trial digit is estimated from the leading digits,
/// corrected downwards while it overshoots, and the matching multiple of the
/// divisor is subtracted from the running remainder.
///
/// # Panics
///
/// - if `n.partial_cmp(x)` returns `None`;
/// - if the final remainder cannot be converted via `to_unsigned()`.
///
/// NOTE(review): a zero divisor presumably reaches `div_digits` through the
/// single-digit path and panics with a division by zero — confirm how
/// `is_digit()` treats zero.
pub fn generic_div_rem<T, const D: usize, const E: usize>(x: &T, n: &Unsigned<D, E>) -> (T, Unsigned<D, E>)
where
    T: NumberMut + PartialOrd,
    T: core::ops::ShrAssign<usize>,
    for<'a> &'a T: core::ops::Shl<usize, Output = T>,
    for<'a> Wrapping<T>: core::ops::SubAssign<&'a Unsigned<D, E>>,
    for<'a> Wrapping<T>: core::ops::SubAssign<&'a T>,
{
    // Trivial dividend.
    if x.is_zero() {
        return (T::zero(), Unsigned::zero());
    }

    // Single-digit divisor: plain short division on a copy of the dividend.
    if n.is_digit() {
        let small = n[0];
        let mut quotient = x.clone();
        if small == 1 {
            return (quotient, Unsigned::zero());
        }
        let rem = div_rem_assign_digit(&mut quotient, small);
        return (quotient, rem.into());
    }

    // Divisor at least as large as the dividend: the answer is immediate.
    match n.partial_cmp(x).unwrap() {
        Ordering::Less => {}
        Ordering::Equal => return (T::one(), Unsigned::zero()),
        Ordering::Greater => return (T::zero(), Unsigned::from_slice(x.significant_digits())),
    }

    // Normalize: shift both operands so the divisor's top digit has its
    // highest bit set; the shift is undone on the remainder at the end.
    let shift_bits = n.leading_digit().unwrap().leading_zeros() as usize;
    let mut r: T = x << shift_bits;
    let n: Unsigned<D, E> = n << shift_bits;

    // The normalized divisor is not modified below, so these stay fixed.
    let n_len = n.significant_digits().len();
    let n_top = n.leading_digit().unwrap();

    let q_len = x.significant_digits().len() - n_len + 1;
    let mut q = T::zero();
    let mut trial = T::zero();

    for j in (0..q_len).rev() {
        let offset = j + n_len - 1;
        let r_len = r.significant_digits().len();
        if offset >= r_len {
            // The running remainder is already too short to contribute here.
            continue;
        }

        // Estimate: top digits of the remainder divided by the divisor's top digit.
        trial.set_zero();
        trial[..r_len - offset].copy_from_slice(&r[offset..r_len]);
        div_rem_assign_digit(&mut trial, n_top);

        // The estimate can only be too large; step it down until
        // `trial * n` no longer exceeds the remainder from position `j` up.
        let mut prod = super::multiply::wrapping_mul(&trial, &n);
        while prod > T::from_slice(&r[j..]) {
            *Wrapping::ref_cast_mut(&mut trial) -= &T::one();
            *Wrapping::ref_cast_mut(&mut prod) -= &n;
        }

        // Record the quotient part and remove its contribution from the remainder.
        super::add::wrapping_add_assign(&mut q[j..], trial.significant_digits());
        super::subtract::sub_assign_borrow(&mut r[j..], prod.significant_digits());
    }

    debug_assert!(n > r);

    // Undo the normalization shift on the remainder.
    r >>= shift_bits;
    (q, r.to_unsigned().unwrap())
}
impl<const D: usize, const E: usize> Wrapping<Unsigned<D, E>> {
pub fn inv(&self) -> Result<Self> {
wrapping_invert(&self.0).map(Wrapping)
}
}
impl<const D: usize, const E: usize> Unsigned<D, E> {
    /// Returns the wrapping multiplicative inverse of `self`.
    ///
    /// # Errors
    ///
    /// Forwards the error of `wrapping_invert` when `self` has no inverse.
    pub fn wrapping_inv(&self) -> Result<Self> {
        // `self` is already a `&Unsigned<D, E>`; no extra borrow is needed.
        wrapping_invert(self)
    }
}
/// `&Unsigned / &Unsigned`: the quotient computed by `generic_div_rem`,
/// returned in the dividend's type.
impl<'a, const D: usize, const E: usize, const F: usize, const G: usize> Div<&'a Unsigned<F, G>> for &'a Unsigned<D, E> {
    type Output = Unsigned<D, E>;

    fn div(self, modulus: &'a Unsigned<F, G>) -> Self::Output {
        let (quotient, _remainder) = generic_div_rem(self, modulus);
        quotient
    }
}
/// `Unsigned / &Unsigned`: the quotient computed by `generic_div_rem`,
/// returned in the dividend's type.
impl<'a, const D: usize, const E: usize, const F: usize, const G: usize> Div<&'a Unsigned<F, G>> for Unsigned<D, E> {
    type Output = Unsigned<D, E>;

    fn div(self, modulus: &'a Unsigned<F, G>) -> Self::Output {
        let (quotient, _remainder) = generic_div_rem(&self, modulus);
        quotient
    }
}
/// `&Unsigned / Unsigned`: the quotient computed by `generic_div_rem`,
/// returned in the dividend's type.
impl<'a, const D: usize, const E: usize, const F: usize, const G: usize> Div<Unsigned<F, G>> for &'a Unsigned<D, E> {
    type Output = Unsigned<D, E>;

    fn div(self, modulus: Unsigned<F, G>) -> Self::Output {
        let (quotient, _remainder) = generic_div_rem(self, &modulus);
        quotient
    }
}
/// `Unsigned / Unsigned`: the quotient computed by `generic_div_rem`,
/// returned in the dividend's type.
impl<const D: usize, const E: usize, const F: usize, const G: usize> Div<Unsigned<F, G>> for Unsigned<D, E> {
    type Output = Unsigned<D, E>;

    fn div(self, modulus: Unsigned<F, G>) -> Self::Output {
        let (quotient, _remainder) = generic_div_rem(&self, &modulus);
        quotient
    }
}
/// `&Unsigned % &Unsigned`: the remainder computed by `generic_div_rem`,
/// returned in the modulus' type.
impl<'a, const D: usize, const E: usize, const F: usize, const G: usize> Rem<&'a Unsigned<F, G>> for &'a Unsigned<D, E> {
    type Output = Unsigned<F, G>;

    fn rem(self, modulus: &'a Unsigned<F, G>) -> Self::Output {
        generic_div_rem(self, modulus).1
    }
}
/// `Unsigned % &Unsigned`: the remainder computed by `generic_div_rem`,
/// returned in the modulus' type.
impl<'a, const D: usize, const E: usize, const F: usize, const G: usize> Rem<&'a Unsigned<F, G>> for Unsigned<D, E> {
    type Output = Unsigned<F, G>;

    fn rem(self, modulus: &'a Unsigned<F, G>) -> Self::Output {
        let (_quotient, remainder) = generic_div_rem(&self, modulus);
        remainder
    }
}
/// `&Unsigned % Unsigned`: the remainder computed by `generic_div_rem`,
/// returned in the modulus' type.
impl<'a, const D: usize, const E: usize, const F: usize, const G: usize> Rem<Unsigned<F, G>> for &'a Unsigned<D, E> {
    type Output = Unsigned<F, G>;

    fn rem(self, modulus: Unsigned<F, G>) -> Self::Output {
        let (_quotient, remainder) = generic_div_rem(self, &modulus);
        remainder
    }
}
/// `Unsigned % Unsigned`: the remainder computed by `generic_div_rem`,
/// returned in the modulus' type.
///
/// Both operands are taken by value, so the impl needs no lifetime parameter.
impl<const D: usize, const E: usize, const F: usize, const G: usize> Rem<Unsigned<F, G>> for Unsigned<D, E> {
    type Output = Unsigned<F, G>;

    fn rem(self, modulus: Unsigned<F, G>) -> Self::Output {
        generic_div_rem(&self, &modulus).1
    }
}
/// `&Array % &Unsigned`: the remainder computed by `generic_div_rem`,
/// returned in the modulus' type.
impl<'a, const D: usize, const E: usize, const F: usize, const G: usize, const L: usize> Rem<&'a Unsigned<F, G>> for &'a Array<D, E, L> {
    type Output = Unsigned<F, G>;

    fn rem(self, modulus: &'a Unsigned<F, G>) -> Self::Output {
        let (_quotient, remainder) = generic_div_rem(self, modulus);
        remainder
    }
}
// Unit tests for division, remainder and wrapping inversion.
#[cfg(test)]
mod test {
use super::*;
// `-1` and `-2` cast to `Digit`; presumably the all-ones digit and the one below it
// (assumes `Digit` is an unsigned integer type — confirm).
pub const N1: Digit = -1i64 as Digit;
pub const N2: Digit = -2i64 as Digit;
// Largest value a single digit can hold.
pub const M: Digit = Digit::MAX;
// Asserts `left op right == expected` for all four combinations of
// borrowed and owned (cloned) operands, so every operator impl is exercised.
macro_rules! assert_op {
($left:ident $op:tt $right:ident == $expected:expr) => {
assert_eq!((&$left) $op (&$right), $expected);
assert_eq!((&$left) $op $right.clone(), $expected);
assert_eq!($left.clone() $op (&$right), $expected);
assert_eq!($left.clone() $op $right.clone(), $expected);
};
}
// Triples `(a, b, c)` of digit slices, used below as `c / a == b`, `c / b == a`
// with zero remainder, i.e. `c = a * b`. Slices list the least significant
// digit first (e.g. `(M / 2 + 1) * 2` is `[0, 1]`).
pub const MUL_TRIPLES: &'static [(&'static [Digit], &'static [Digit], &'static [Digit])] = &[
(&[], &[], &[]),
(&[], &[1], &[]),
(&[2], &[], &[]),
(&[1], &[1], &[1]),
(&[2], &[3], &[6]),
(&[1], &[1, 1, 1], &[1, 1, 1]),
(&[1, 2, 3], &[3], &[3, 6, 9]),
(&[1, 1, 1], &[N1], &[N1, N1, N1]),
(&[1, 2, 3], &[N1], &[N1, N2, N2, 2]),
(&[1, 2, 3, 4], &[N1], &[N1, N2, N2, N2, 3]),
(&[N1], &[N1], &[1, N2]),
(&[N1, N1], &[N1], &[1, N1, N2]),
(&[N1, N1, N1], &[N1], &[1, N1, N1, N2]),
(&[N1, N1, N1, N1], &[N1], &[1, N1, N1, N1, N2]),
(&[M / 2 + 1], &[2], &[0, 1]),
(&[0, M / 2 + 1], &[2], &[0, 0, 1]),
(&[1, 2], &[1, 2, 3], &[1, 4, 7, 6]),
(&[N1, N1], &[N1, N1, N1], &[1, 0, N1, N2, N1]),
(&[N1, N1, N1], &[N1, N1, N1, N1], &[1, 0, 0, N1, N2, N1, N1]),
(&[0, 0, 1], &[1, 2, 3], &[0, 0, 1, 2, 3]),
(&[0, 0, 1], &[0, 0, 0, 1], &[0, 0, 0, 0, 0, 1]),
];
// Quadruples `(a, b, c, d)` of digit slices, used below as
// `a / b == c` and `a % b == d`.
pub const DIV_REM_QUADRUPLES: &'static [(
&'static [Digit],
&'static [Digit],
&'static [Digit],
&'static [Digit],
)] = &[
(&[1], &[2], &[], &[1]),
(&[3], &[2], &[1], &[1]),
(&[1, 1], &[2], &[M / 2 + 1], &[1]),
(&[1, 1, 1], &[2], &[M / 2 + 1, M / 2 + 1], &[1]),
(&[0, 1], &[N1], &[1], &[1]),
(&[N1, N1], &[N2], &[2, 1], &[3]),
];
// Reducing 1 modulo the fixture value `p256()` must give 1 back.
#[test]
fn rem_of_1() {
use crate::fixtures::*;
let short1 = Short256::from(1);
let p = p256().into_unsigned();
let remainder = &short1 % &p;
assert_eq!(remainder, short1);
}
// Checks `/`, `%` and `generic_div_rem` against both tables above.
#[test]
fn test_div_rem() {
use crate::Short;
// Exact divisions: dividing a product by either non-zero factor
// yields the other factor and a zero remainder.
for case in MUL_TRIPLES.iter() {
let (a_vec, b_vec, c_vec) = *case;
let a = Short::<7>::from_slice(a_vec);
let b = Short::<7>::from_slice(b_vec);
let c = Short::<7>::from_slice(c_vec);
if !a.is_zero() {
assert_op!(c / a == b);
assert_op!(c % a == Short::<7>::zero());
assert_eq!(generic_div_rem(&c, &a), (b.clone(), Unsigned::zero()));
}
if !b.is_zero() {
assert_op!(c / b == a);
assert_op!(c % b == Short::<7>::zero());
assert_eq!(generic_div_rem(&c, &b), (a.clone(), Unsigned::zero()));
}
}
// Divisions with a non-zero remainder.
for case in DIV_REM_QUADRUPLES.iter() {
let (a_vec, b_vec, c_vec, d_vec) = *case;
let a = Short::<7>::from_slice(a_vec);
let b = Short::<7>::from_slice(b_vec);
let c = Short::<7>::from_slice(c_vec);
let d = Short::<7>::from_slice(d_vec);
if !b.is_zero() {
assert_op!(a / b == c);
assert_op!(a % b == d);
assert!(generic_div_rem(&a, &b) == (c, d));
}
}
}
// The wrapping product of a value and its computed inverse must be one,
// and inversion of `[0x2, 0x1]` must be rejected.
#[test]
fn test_invert() {
use hex_literal::hex;
use crate::{fixtures::Long256, Long};
let x = Long256::from_bytes(&hex!(
"756a33dea26163dfae8303747b3db15dc071fc5a0d75de209881a570678b33bb"));
let maybe_inverse = wrapping_invert(&x).unwrap();
let maybe_one = crate::arithmetic::multiply::wrapping_mul(&maybe_inverse, &x);
assert_eq!(maybe_one, Long::<2>::one());
// Lowest digit is 2, so this value is even (least significant digit first).
let x = Long::<2>::from_slice(&[0x2, 0x1]);
assert!(wrapping_invert(&x).is_err());
}
}