[FAQ] If I have a very, very large loop, how can I significantly reduce the computation time?

For example, in Python you can use binary expansion (double-and-add). How can I do the same in Julia?

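# code
def __rmul__(self, coefficient):
    # Scalar multiplication by binary expansion (double-and-add).
    coef = coefficient
    current = self
    # Start from the point at infinity (the additive identity).
    result = self.__class__(None, None, self.a, self.b, True)
    while coef:
        if coef & 1:
            result += current
        current += current
        coef >>= 1
    return result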

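Here is a concrete example.

Define point addition on an elliptic curve. Scalar multiplication is then defined on top of it, and can only be computed by repeated addition:

n \times Point = Point + Point + Point + Point + ... + Point

But when the additions are done in a loop and n is large, the run time becomes very long.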

import Base: +, -, *, ^, /, ==
struct FieldElement
    Num::BigInt
    Prime::BigInt
    function FieldElement(Num, Prime)
        if Num < 0 || Num >= Prime
            return "Num $(Num) not in field range 0 to $(Prime)"
        else
            return new(Num, Prime)
        end
    end
end

#show FieldElement
function Base.show(io::IO, a::FieldElement)
    print(io, "FieldElement[$(a.Prime)]: $(a.Num)")
end
#FieldElement: ==
function ==(a::FieldElement, b::FieldElement)
    return a.Num == b.Num && a.Prime == b.Prime
end
#FieldElement: +, -
function +(a::FieldElement, b::FieldElement)
    if a.Prime != b.Prime
        return "cann't add in different fields"
    else
        return FieldElement(mod(a.Num+b.Num, a.Prime), a.Prime)
    end
end
function -(a::FieldElement, b::FieldElement)
    if a.Prime != b.Prime
        return "cann't - in different fields"
    else
        return FieldElement(mod(a.Num-b.Num, a.Prime), a.Prime)
    end
end
function -(a::FieldElement)
    return FieldElement(mod(-a.Num, a.Prime), a.Prime)
end
#FieldElement: *, ^
function *(a::FieldElement, b::FieldElement)
    if a.Prime != b.Prime
        return "cannot mul in different fields"
    else
        return FieldElement(mod(a.Num*b.Num, a.Prime), a.Prime)
    end
end
function *(a::BigInt, b::FieldElement)
    # does not handle the case a < 0
    return FieldElement(mod(a*b.Num, b.Prime), b.Prime)
end
function ^(a::FieldElement, b::BigInt)
    if b < 0
        while b < 0
            b += a.Prime - 1
        end
        return FieldElement(powermod(a.Num, b, a.Prime), a.Prime)
    else
        return FieldElement(powermod(a.Num, b, a.Prime), a.Prime)
    end
end
#FieldElement: /
function /(a::FieldElement, b::FieldElement)
    if a.Prime != b.Prime
        return "can't div in different fields"
    else
        return a*FieldElement(powermod(b.Num, b.Prime-2, b.Prime), b.Prime)
    end
end
include("FieldElement.jl")
struct EPoint
    x::FieldElement
    y::FieldElement
    a::FieldElement
    b::FieldElement
    infinity::Bool
    function EPoint(x, y, a, b, infinity)
        if infinity
            return new(x, y, a, b, infinity)
        end
        if y^BigInt(2) != x^BigInt(3) + a * x + b
            return "Point not on the curve"
        else
            return new(x, y, a, b, infinity)
        end
    end
end

#show EPoint
function Base.show(io::IO, A::EPoint)
    if A.infinity
        print(io, "Point(infinity)")
    else
        print(io, "Point($(A.x.Num),$(A.y.Num))-FieldElement($(A.x.Prime))")
    end
end

#EPoint: ==
function ==(A::EPoint, B::EPoint)
    return A.x == B.x && A.y == B.y && A.a == B.a && A.b == B.b && A.infinity == B.infinity
end

#EPoint: add,
function +(A::EPoint, B::EPoint)
    Prime = A.a.Prime
    coef1, coef2 = FieldElement(3, Prime), FieldElement(2, Prime)
    sp_point = EPoint(FieldElement(0, Prime),FieldElement(0, Prime),A.a,A.b,true)

    if A.infinity && B.infinity != true
        return B
    end
    if B.infinity 
        return A
    end
    if A == B
        if A.y == FieldElement(0, A.y.Prime)
            return sp_point
        else
            k = (coef1*(A.x^BigInt(2)) + A.a)/(coef2*A.y)
            x₃ = k^BigInt(2) - coef2*A.x
            y₃ = k * (A.x - x₃) - A.y
            return EPoint(x₃, y₃, A.a, A.b, false)
        end
    else
        if A.x == B.x && A.y != B.y
            return sp_point
        else
            k = (A.y - B.y)/(A.x - B.x)
            x₃ = k^BigInt(2) - A.x - B.x
            y₃ = k * (A.x - x₃) - A.y
            return EPoint(x₃, y₃, A.a, A.b, false)
        end
    end
end

function *(a::BigInt, A::EPoint)
    Prime = A.a.Prime
    sum = EPoint(FieldElement(0, Prime),FieldElement(0, Prime),A.a,A.b,true)
    for _ in 1:a
        sum += A
    end
    return sum
end

p = 115792089237316195423570985008687907853269984665640564039457584007908834671663
Gx = 55066263022277343669578718895168534326250603453777594175500187360389116729240
Gy = 32670510020758816978083085130507043184471273380659243275938904335757337482424
N = 115792089237316195423570985008687907852837564279074904382605163141518161494337



x = FieldElement(Gx, p)
y = FieldElement(Gy, p)
seven = FieldElement(7, p)
z = FieldElement(0, p)
G = EPoint(x,y,z,seven,false)
using BenchmarkTools

@btime (BigInt(2)^BigInt(20))*G
#@btime BigInt(N) * G
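
The binary expansion from the Python snippet at the top carries over to Julia almost line for line. A minimal sketch, assuming only that the point type supports `+`; `double_and_add` and `zero_elem` are hypothetical names, and plain integers stand in for curve points so the block runs on its own:

```julia
# Double-and-add: computes n * P using only `+`.
# The loop runs once per bit of n, so it needs O(log n) additions
# instead of the n additions of the plain loop.
# `zero_elem` is the neutral element of `+` (for a curve: the point at infinity).
function double_and_add(n::Integer, P, zero_elem)
    n >= 0 || throw(ArgumentError("n must be non-negative"))
    result = zero_elem
    current = P
    while n > 0
        if isodd(n)
            result = result + current   # this bit of n is set
        end
        current = current + current     # doubling step
        n >>= 1                         # move on to the next bit
    end
    return result
end

# Stand-in check with integers, where n * P is ordinary multiplication.
double_and_add(big(2)^20, 7, 0) == 7 * 2^20
```

With the types from this post, `*(a::BigInt, A::EPoint)` could presumably call `double_and_add(a, A, sum)` with `sum` being the infinity point it already constructs. For the benchmarked `2^20` that would mean 21 doublings and a single addition rather than about a million additions.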

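I'm honestly not very familiar with this. For instance, my impression is that the additions here cannot be reordered (otherwise you could use multithreading). Still, you can take a look at the profiling and performance tips sections of the documentation. Also, many of your functions have type-unstable return values. For example: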

#FieldElement: *, ^
function *(a::FieldElement, b::FieldElement)
    if a.Prime != b.Prime
        return "cannot mul in different fields" # a String
    else
        return FieldElement(mod(a.Num*b.Num, a.Prime), a.Prime) # a FieldElement
    end
end

The first suggestion: where an exception should be thrown, do not return a String.

function +(a::FieldElement, b::FieldElement)
-    if a.Prime != b.Prime
-        return "cann't add in different fields"
-    else
-        return FieldElement(mod(a.Num+b.Num, a.Prime), a.Prime)
-    end
+    a.Prime == b.Prime || throw(ArgumentError("Can't add in different fields"))
+    FieldElement(mod(a.Num+b.Num, a.Prime), a.Prime)
end
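
One way to see what this change buys is to ask Julia's type inference about both variants. A small sketch, assuming a simplified stand-in type `Elem` with `Int` fields (hypothetical, not the `FieldElement` from the question) so that it runs on its own:

```julia
# Simplified stand-in for FieldElement (hypothetical, Int fields only).
struct Elem
    num::Int
    prime::Int
end

# Returns a String on mismatch: the return type depends on the values.
unstable_mul(a::Elem, b::Elem) =
    a.prime != b.prime ? "cannot mul in different fields" :
                         Elem(mod(a.num * b.num, a.prime), a.prime)

# Throws on mismatch: every normal return is an Elem.
function stable_mul(a::Elem, b::Elem)
    a.prime == b.prime || throw(ArgumentError("cannot mul in different fields"))
    return Elem(mod(a.num * b.num, a.prime), a.prime)
end

# Inferred return types for (Elem, Elem) arguments.
unstable_ret = only(Base.return_types(unstable_mul, (Elem, Elem)))
stable_ret   = only(Base.return_types(stable_mul, (Elem, Elem)))

isconcretetype(unstable_ret)  # false: typically a Union of String and Elem
isconcretetype(stable_ret)    # true: Elem
```

`@code_warntype stable_mul(Elem(3, 7), Elem(5, 7))` gives the same information interactively.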

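The second suggestion, which may or may not help: because operations like +/- are so basic, even a check such as a.Prime == b.Prime is extra overhead on every call, so avoiding it might bring a noticeable performance gain.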

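struct FieldElement{P}
    Num::BigInt
    # P becomes a type parameter, so it has to be an isbits value
    # (e.g. an Int); a BigInt prime cannot be used as a parameter directly.
    function FieldElement(Num, Prime)
        return new{Prime}(Num)
    end
end

# Only the two-argument constructor exists, so call that one.
+(a::FieldElement{P}, b::FieldElement{P}) where P = FieldElement(mod(a.Num + b.Num, P), P)
+(a::FieldElement, b::FieldElement) = throw(ArgumentError("Can't add in different fields"))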

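I'm not sure whether turning P into a type parameter will actually improve performance. One thing is certain, though: it puts a lot of pressure on the compiler, because every distinct value of P creates a new type that needs its own compiled methods, and P can take a great many values.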
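
To make the dispatch behaviour concrete, here is a minimal runnable sketch of the same idea with a small `Int` modulus. `Fp` is a hypothetical name used only for this illustration:

```julia
import Base: +

# Field element whose modulus P lives in the type, not in a field.
struct Fp{P}
    num::Int
    # Reduce on construction so num is always in 0:P-1.
    Fp{P}(num::Integer) where {P} = new{P}(mod(num, P))
end

# Same modulus: no runtime check needed, dispatch already guarantees it.
+(a::Fp{P}, b::Fp{P}) where {P} = Fp{P}(a.num + b.num)

# Different moduli fall through to this less specific method.
+(a::Fp, b::Fp) = throw(ArgumentError("can't add in different fields"))

s = Fp{7}(5) + Fp{7}(4)   # s.num == 2, since 9 mod 7 == 2
```

Whether this is faster than the runtime comparison would have to be measured, for example with `@btime`.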

The third suggestion: use simd to get CPU-level parallelism.

function *(a::BigInt, A::EPoint)
    Prime = A.a.Prime
    sum = EPoint(FieldElement(0, Prime),FieldElement(0, Prime),A.a,A.b,true)
-   for _ in 1:a
+   @simd for _ in 1:a 
        sum += A
    end
    return sum
end

The order of the additions inside the for loop may get shuffled here, but that shouldn't affect the result, should it?

OK, thank you for the suggestions.

I'll go try them. Thanks!

(Question) What is the difference between the @simd macro and the @threads macro mentioned earlier?

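@simd is CPU-level parallelism at the instruction-set level. We cannot control it directly; the macro only tells the compiler that it is allowed to vectorize the loop. Whether SIMD instructions are actually used is decided by the compiler: @simd only takes effect for very simple operations, and it generally does nothing for loop bodies that contain `if` branches.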

@threads is parallelism at the thread/coroutine level.

Distributed is parallelism at the process level.
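
A small sketch of the first two levels side by side, using plain `Float64` arrays rather than the curve types from this thread. The function names `simd_sum` and `threaded_squares` are hypothetical:

```julia
using Base.Threads: @threads

# @simd: a simple reduction with no branches in the loop body.
# The macro allows the compiler to reorder the additions and vectorize.
function simd_sum(xs::Vector{Float64})
    s = 0.0
    @simd for i in eachindex(xs)
        @inbounds s += xs[i]
    end
    return s
end

# @threads: iterations are independent and are split across threads.
# Each iteration writes to its own slot, so there is no data race.
function threaded_squares(xs::Vector{Float64})
    out = similar(xs)
    @threads for i in eachindex(xs)
        out[i] = xs[i]^2
    end
    return out
end

total   = simd_sum(collect(1.0:100.0))
squares = threaded_squares([1.0, 2.0, 3.0])
```

Neither pattern fits the `sum += A` loop well: each iteration depends on the previous value of `sum`, and the body is far from a simple arithmetic instruction. `@threads` only runs in parallel when Julia is started with more than one thread (for example `julia -t 4`).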
