diff options
Diffstat (limited to 'lib/mpi/mpicoder.c')
| -rw-r--r-- | lib/mpi/mpicoder.c | 249 | 
1 files changed, 81 insertions, 168 deletions
diff --git a/lib/mpi/mpicoder.c b/lib/mpi/mpicoder.c index 747606f9e4a3..c6272ae2015e 100644 --- a/lib/mpi/mpicoder.c +++ b/lib/mpi/mpicoder.c @@ -21,6 +21,7 @@  #include <linux/bitops.h>  #include <linux/count_zeros.h>  #include <linux/byteorder/generic.h> +#include <linux/scatterlist.h>  #include <linux/string.h>  #include "mpi-internal.h" @@ -50,9 +51,7 @@ MPI mpi_read_raw_data(const void *xbuffer, size_t nbytes)  		return NULL;  	}  	if (nbytes > 0) -		nbits -= count_leading_zeros(buffer[0]); -	else -		nbits = 0; +		nbits -= count_leading_zeros(buffer[0]) - (BITS_PER_LONG - 8);  	nlimbs = DIV_ROUND_UP(nbytes, BYTES_PER_MPI_LIMB);  	val = mpi_alloc(nlimbs); @@ -82,50 +81,30 @@ EXPORT_SYMBOL_GPL(mpi_read_raw_data);  MPI mpi_read_from_buffer(const void *xbuffer, unsigned *ret_nread)  {  	const uint8_t *buffer = xbuffer; -	int i, j; -	unsigned nbits, nbytes, nlimbs, nread = 0; -	mpi_limb_t a; -	MPI val = NULL; +	unsigned int nbits, nbytes; +	MPI val;  	if (*ret_nread < 2) -		goto leave; +		return ERR_PTR(-EINVAL);  	nbits = buffer[0] << 8 | buffer[1];  	if (nbits > MAX_EXTERN_MPI_BITS) {  		pr_info("MPI: mpi too large (%u bits)\n", nbits); -		goto leave; +		return ERR_PTR(-EINVAL);  	} -	buffer += 2; -	nread = 2;  	nbytes = DIV_ROUND_UP(nbits, 8); -	nlimbs = DIV_ROUND_UP(nbytes, BYTES_PER_MPI_LIMB); -	val = mpi_alloc(nlimbs); -	if (!val) -		return NULL; -	i = BYTES_PER_MPI_LIMB - nbytes % BYTES_PER_MPI_LIMB; -	i %= BYTES_PER_MPI_LIMB; -	val->nbits = nbits; -	j = val->nlimbs = nlimbs; -	val->sign = 0; -	for (; j > 0; j--) { -		a = 0; -		for (; i < BYTES_PER_MPI_LIMB; i++) { -			if (++nread > *ret_nread) { -				printk -				    ("MPI: mpi larger than buffer nread=%d ret_nread=%d\n", -				     nread, *ret_nread); -				goto leave; -			} -			a <<= 8; -			a |= *buffer++; -		} -		i = 0; -		val->d[j - 1] = a; +	if (nbytes + 2 > *ret_nread) { +		pr_info("MPI: mpi larger than buffer nbytes=%u ret_nread=%u\n", +				nbytes, *ret_nread); +		return ERR_PTR(-EINVAL);  	} -leave: -	*ret_nread = nread; +	val = mpi_read_raw_data(buffer + 2, nbytes); +	if (!val) +		return ERR_PTR(-ENOMEM); + +	*ret_nread = nbytes + 2;  	return val;  }  EXPORT_SYMBOL_GPL(mpi_read_from_buffer); @@ -250,82 +229,6 @@ void *mpi_get_buffer(MPI a, unsigned *nbytes, int *sign)  }  EXPORT_SYMBOL_GPL(mpi_get_buffer); -/**************** - * Use BUFFER to update MPI. - */ -int mpi_set_buffer(MPI a, const void *xbuffer, unsigned nbytes, int sign) -{ -	const uint8_t *buffer = xbuffer, *p; -	mpi_limb_t alimb; -	int nlimbs; -	int i; - -	nlimbs = DIV_ROUND_UP(nbytes, BYTES_PER_MPI_LIMB); -	if (RESIZE_IF_NEEDED(a, nlimbs) < 0) -		return -ENOMEM; -	a->sign = sign; - -	for (i = 0, p = buffer + nbytes - 1; p >= buffer + BYTES_PER_MPI_LIMB;) { -#if BYTES_PER_MPI_LIMB == 4 -		alimb = (mpi_limb_t) *p--; -		alimb |= (mpi_limb_t) *p-- << 8; -		alimb |= (mpi_limb_t) *p-- << 16; -		alimb |= (mpi_limb_t) *p-- << 24; -#elif BYTES_PER_MPI_LIMB == 8 -		alimb = (mpi_limb_t) *p--; -		alimb |= (mpi_limb_t) *p-- << 8; -		alimb |= (mpi_limb_t) *p-- << 16; -		alimb |= (mpi_limb_t) *p-- << 24; -		alimb |= (mpi_limb_t) *p-- << 32; -		alimb |= (mpi_limb_t) *p-- << 40; -		alimb |= (mpi_limb_t) *p-- << 48; -		alimb |= (mpi_limb_t) *p-- << 56; -#else -#error please implement for this limb size. -#endif -		a->d[i++] = alimb; -	} -	if (p >= buffer) { -#if BYTES_PER_MPI_LIMB == 4 -		alimb = *p--; -		if (p >= buffer) -			alimb |= (mpi_limb_t) *p-- << 8; -		if (p >= buffer) -			alimb |= (mpi_limb_t) *p-- << 16; -		if (p >= buffer) -			alimb |= (mpi_limb_t) *p-- << 24; -#elif BYTES_PER_MPI_LIMB == 8 -		alimb = (mpi_limb_t) *p--; -		if (p >= buffer) -			alimb |= (mpi_limb_t) *p-- << 8; -		if (p >= buffer) -			alimb |= (mpi_limb_t) *p-- << 16; -		if (p >= buffer) -			alimb |= (mpi_limb_t) *p-- << 24; -		if (p >= buffer) -			alimb |= (mpi_limb_t) *p-- << 32; -		if (p >= buffer) -			alimb |= (mpi_limb_t) *p-- << 40; -		if (p >= buffer) -			alimb |= (mpi_limb_t) *p-- << 48; -		if (p >= buffer) -			alimb |= (mpi_limb_t) *p-- << 56; -#else -#error please implement for this limb size. -#endif -		a->d[i++] = alimb; -	} -	a->nlimbs = i; - -	if (i != nlimbs) { -		pr_emerg("MPI: mpi_set_buffer: Assertion failed (%d != %d)", i, -		       nlimbs); -		BUG(); -	} -	return 0; -} -EXPORT_SYMBOL_GPL(mpi_set_buffer); -  /**   * mpi_write_to_sgl() - Funnction exports MPI to an sgl (msb first)   * @@ -335,16 +238,13 @@ EXPORT_SYMBOL_GPL(mpi_set_buffer);   * @a:		a multi precision integer   * @sgl:	scatterlist to write to. Needs to be at least   *		mpi_get_size(a) long. - * @nbytes:	in/out param - it has the be set to the maximum number of - *		bytes that can be written to sgl. This has to be at least - *		the size of the integer a. On return it receives the actual - *		length of the data written on success or the data that would - *		be written if buffer was too small. + * @nbytes:	the number of bytes to write.  Leading bytes will be + *		filled with zero.   * @sign:	if not NULL, it will be set to the sign of a.   *   * Return:	0 on success or error code in case of error   */ -int mpi_write_to_sgl(MPI a, struct scatterlist *sgl, unsigned *nbytes, +int mpi_write_to_sgl(MPI a, struct scatterlist *sgl, unsigned nbytes,  		     int *sign)  {  	u8 *p, *p2; @@ -356,55 +256,60 @@ int mpi_write_to_sgl(MPI a, struct scatterlist *sgl, unsigned *nbytes,  #error please implement for this limb size.  #endif  	unsigned int n = mpi_get_size(a); -	int i, x, y = 0, lzeros, buf_len; - -	if (!nbytes) -		return -EINVAL; +	struct sg_mapping_iter miter; +	int i, x, buf_len; +	int nents;  	if (sign)  		*sign = a->sign; -	lzeros = count_lzeros(a); - -	if (*nbytes < n - lzeros) { -		*nbytes = n - lzeros; +	if (nbytes < n)  		return -EOVERFLOW; -	} -	*nbytes = n - lzeros; -	buf_len = sgl->length; -	p2 = sg_virt(sgl); +	nents = sg_nents_for_len(sgl, nbytes); +	if (nents < 0) +		return -EINVAL; -	for (i = a->nlimbs - 1 - lzeros / BYTES_PER_MPI_LIMB, -			lzeros %= BYTES_PER_MPI_LIMB; -		i >= 0; i--) { +	sg_miter_start(&miter, sgl, nents, SG_MITER_ATOMIC | SG_MITER_TO_SG); +	sg_miter_next(&miter); +	buf_len = miter.length; +	p2 = miter.addr; + +	while (nbytes > n) { +		i = min_t(unsigned, nbytes - n, buf_len); +		memset(p2, 0, i); +		p2 += i; +		nbytes -= i; + +		buf_len -= i; +		if (!buf_len) { +			sg_miter_next(&miter); +			buf_len = miter.length; +			p2 = miter.addr; +		} +	} + +	for (i = a->nlimbs - 1; i >= 0; i--) {  #if BYTES_PER_MPI_LIMB == 4 -		alimb = cpu_to_be32(a->d[i]); +		alimb = a->d[i] ? cpu_to_be32(a->d[i]) : 0;  #elif BYTES_PER_MPI_LIMB == 8 -		alimb = cpu_to_be64(a->d[i]); +		alimb = a->d[i] ? cpu_to_be64(a->d[i]) : 0;  #else  #error please implement for this limb size.  #endif -		if (lzeros) { -			y = lzeros; -			lzeros = 0; -		} +		p = (u8 *)&alimb; -		p = (u8 *)&alimb + y; - -		for (x = 0; x < sizeof(alimb) - y; x++) { -			if (!buf_len) { -				sgl = sg_next(sgl); -				if (!sgl) -					return -EINVAL; -				buf_len = sgl->length; -				p2 = sg_virt(sgl); -			} +		for (x = 0; x < sizeof(alimb); x++) {  			*p2++ = *p++; -			buf_len--; +			if (!--buf_len) { +				sg_miter_next(&miter); +				buf_len = miter.length; +				p2 = miter.addr; +			}  		} -		y = 0;  	} + +	sg_miter_stop(&miter);  	return 0;  }  EXPORT_SYMBOL_GPL(mpi_write_to_sgl); @@ -424,19 +329,23 @@ EXPORT_SYMBOL_GPL(mpi_write_to_sgl);   */  MPI mpi_read_raw_from_sgl(struct scatterlist *sgl, unsigned int nbytes)  { -	struct scatterlist *sg; -	int x, i, j, z, lzeros, ents; +	struct sg_mapping_iter miter;  	unsigned int nbits, nlimbs; +	int x, j, z, lzeros, ents; +	unsigned int len; +	const u8 *buff;  	mpi_limb_t a;  	MPI val = NULL; -	lzeros = 0; -	ents = sg_nents(sgl); +	ents = sg_nents_for_len(sgl, nbytes); +	if (ents < 0) +		return NULL; -	for_each_sg(sgl, sg, ents, i) { -		const u8 *buff = sg_virt(sg); -		int len = sg->length; +	sg_miter_start(&miter, sgl, ents, SG_MITER_ATOMIC | SG_MITER_FROM_SG); +	lzeros = 0; +	len = 0; +	while (nbytes > 0) {  		while (len && !*buff) {  			lzeros++;  			len--; @@ -446,12 +355,14 @@ MPI mpi_read_raw_from_sgl(struct scatterlist *sgl, unsigned int nbytes)  		if (len && *buff)  			break; -		ents--; +		sg_miter_next(&miter); +		buff = miter.addr; +		len = miter.length; +  		nbytes -= lzeros;  		lzeros = 0;  	} -	sgl = sg;  	nbytes -= lzeros;  	nbits = nbytes * 8;  	if (nbits > MAX_EXTERN_MPI_BITS) { @@ -460,8 +371,7 @@ MPI mpi_read_raw_from_sgl(struct scatterlist *sgl, unsigned int nbytes)  	}  	if (nbytes > 0) -		nbits -= count_leading_zeros(*(u8 *)(sg_virt(sgl) + lzeros)) - -			(BITS_PER_LONG - 8); +		nbits -= count_leading_zeros(*buff) - (BITS_PER_LONG - 8);  	nlimbs = DIV_ROUND_UP(nbytes, BYTES_PER_MPI_LIMB);  	val = mpi_alloc(nlimbs); @@ -480,21 +390,24 @@ MPI mpi_read_raw_from_sgl(struct scatterlist *sgl, unsigned int nbytes)  	z = BYTES_PER_MPI_LIMB - nbytes % BYTES_PER_MPI_LIMB;  	z %= BYTES_PER_MPI_LIMB; -	for_each_sg(sgl, sg, ents, i) { -		const u8 *buffer = sg_virt(sg) + lzeros; -		int len = sg->length - lzeros; - +	for (;;) {  		for (x = 0; x < len; x++) {  			a <<= 8; -			a |= *buffer++; +			a |= *buff++;  			if (((z + x + 1) % BYTES_PER_MPI_LIMB) == 0) {  				val->d[j--] = a;  				a = 0;  			}  		}  		z += x; -		lzeros = 0; + +		if (!sg_miter_next(&miter)) +			break; + +		buff = miter.addr; +		len = miter.length;  	} +  	return val;  }  EXPORT_SYMBOL_GPL(mpi_read_raw_from_sgl);  | 
