Some things to watch out for when using PyCall

When you use PyCall to call simple Python libraries, you usually don't need to think about performance. For some tasks, though, performance becomes critical.

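About a month ago, PyCall gained a pycall! function (see PR) that makes calling Python functions faster in certain situations. For some reason the README doesn't mention it, but the benchmarks were updated. Running benchmarks/callperf.jl shows the difference in call performance: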

pycall_legacy ()                        TrialEstimate(476.043 ns)
pycall ()                               TrialEstimate(474.159 ns)
pycall! ()                              TrialEstimate(404.471 ns)
_pycall! ()                             TrialEstimate(367.595 ns)
nprand_pywrapfn ()                      TrialEstimate(371.296 ns)
nprand_pywrapfn_noargs ()               TrialEstimate(365.872 ns)
pycall_legacy (1,)                      TrialEstimate(1.129 μs)
pycall (1,)                             TrialEstimate(1.055 μs)
pycall! (1,)                            TrialEstimate(1.105 μs)
_pycall! (1,)                           TrialEstimate(1.044 μs)
nprand_pywrapfn (1,)                    TrialEstimate(1.063 μs)
nprand_pywrapfn_noargs (1,)             TrialEstimate(962.069 ns)
pycall_legacy (1, 1)                    TrialEstimate(4.152 μs)
pycall (1, 1)                           TrialEstimate(1.181 μs)
pycall! (1, 1)                          TrialEstimate(1.268 μs)
_pycall! (1, 1)                         TrialEstimate(1.212 μs)
nprand_pywrapfn (1, 1)                  TrialEstimate(1.159 μs)
nprand_pywrapfn_noargs (1, 1)           TrialEstimate(983.563 ns)
pycall_legacy (1, 1, 1)                 TrialEstimate(5.803 μs)
pycall (1, 1, 1)                        TrialEstimate(1.323 μs)
pycall! (1, 1, 1)                       TrialEstimate(1.257 μs)
_pycall! (1, 1, 1)                      TrialEstimate(1.182 μs)
nprand_pywrapfn (1, 1, 1)               TrialEstimate(1.115 μs)
nprand_pywrapfn_noargs (1, 1, 1)        TrialEstimate(1.108 μs)
pycall_legacy 7*(1,1,...)               TrialEstimate(10.953 μs)
pycall 7*(1,1,...)                      TrialEstimate(1.665 μs)
pycall! 7*(1,1,...)                     TrialEstimate(1.466 μs)
_pycall! 7*(1,1,...)                    TrialEstimate(1.347 μs)
nprand_pywrapfn 7*(1,1,...)             TrialEstimate(1.394 μs)
nprand_pywrapfn_noargs 7*(1,1,...)      TrialEstimate(1.318 μs)
pycall_legacy 12*(1,1,...)              TrialEstimate(19.129 μs)
pycall 12*(1,1,...)                     TrialEstimate(3.257 μs)
pycall! 12*(1,1,...)                    TrialEstimate(1.750 μs)
_pycall! 12*(1,1,...)                   TrialEstimate(1.758 μs)
nprand_pywrapfn 12*(1,1,...)            TrialEstimate(1.728 μs)
nprand_pywrapfn_noargs 12*(1,1,...)     TrialEstimate(1.560 μs)
pycall_legacy 17*(1,1,...)              TrialEstimate(26.639 μs)
pycall 17*(1,1,...)                     TrialEstimate(3.360 μs)
pycall! 17*(1,1,...)                    TrialEstimate(2.088 μs)
_pycall! 17*(1,1,...)                   TrialEstimate(2.020 μs)
nprand_pywrapfn 17*(1,1,...)            TrialEstimate(2.234 μs)
nprand_pywrapfn_noargs 17*(1,1,...)     TrialEstimate(1.609 μs)
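The raw times are easier to compare as ratios. Below is a small Python sketch that divides a few of the TrialEstimate values above. The numbers are copied by hand and converted to microseconds, and only two of the argument lists are included:

```python
# TrialEstimate times (microseconds) copied by hand from the benchmarks/callperf.jl output above.
bench_us = {
    "()": {"pycall_legacy": 0.476043, "pycall": 0.474159, "pycall!": 0.404471},
    "17*(1,1,...)": {"pycall_legacy": 26.639, "pycall": 3.360, "pycall!": 2.088},
}


def speedup(args: str, slow: str, fast: str) -> float:
    """How many times faster `fast` is than `slow` for the given argument list."""
    row = bench_us[args]
    return row[slow] / row[fast]


for args in bench_us:
    print(
        f"{args:>14}: pycall vs legacy {speedup(args, 'pycall_legacy', 'pycall'):.2f}x, "
        f"pycall! vs pycall {speedup(args, 'pycall', 'pycall!'):.2f}x"
    )
```

With no arguments, pycall! is only about 1.17x faster than pycall. With 17 arguments the gap grows to about 1.6x. The largest jump in the table is from pycall_legacy to pycall once arguments are involved (about 7.9x at 17 arguments).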

Here is an example I wrote recently that shows how large the performance difference can be in practice:

First, the Python version, which calls the OpenAI Gym library directly:

import time
import gym
import numpy as np

def test():
    env = gym.make('CartPole-v0')
    env.reset()
    i = 0
    state, reward, done, info = env.step(env.action_space.sample())
    t = time.time()
    while not done:
        i += 1
        state, reward, done, info = env.step(env.action_space.sample())
    return (time.time() - t) / i

ts = [test() for _ in range(1000)]
np.mean(ts), np.std(ts)
# (1.6214491620061298e-05, 4.419731630587962e-06)

Each action takes about 16 μs on average (the values above are seconds per step).
Next, a first implementation using plain PyCall:

Version 1

using PyCall

@pyimport gym
function act1()
    pygymenv = gym.make("CartPole-v0")
    pygymenv[:reset]()
    i = 0
    state, reward, done, info = pygymenv[:step](pygymenv[:action_space][:sample]())
    t = @elapsed while !done
        i += 1
        state, reward, done, info = pygymenv[:step](pygymenv[:action_space][:sample]())
    end
    t / i
end

ts = [act1() for _ in 1:1000]
mean(ts), std(ts)
# (0.0003923129730181138, 0.0002608349108703414)

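Hmm, about 390 μs per step on average, more than 20 times slower than Python (roughly 24x), which is broadly in line with the benchmark results.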

Then rewrite it using pycall with explicit return types:

Version 2

function act2()
    pygymenv = gym.make("CartPole-v0")
    pygymenv[:reset]()
    i = 0
    state, reward, done, info = pygymenv[:step](pygymenv[:action_space][:sample]())
    t = @elapsed while !done
        i += 1
        state, reward, done, info = pycall(pygymenv[:step], PyVector, pycall(pygymenv[:action_space][:sample], PyObject))
    end
    t / i
end

ts = [act2() for _ in 1:1000]
mean(ts), std(ts)
# (0.0003203265628935604, 0.00017397040461608834)

Not much of a change.
Next, switch to pycall!:

Version 3

function act3()
    temp = PyNULL()
    pygymenv = gym.make("CartPole-v0")
    pygymenv[:reset]()
    i = 0
    state, reward, done, info = pygymenv[:step](pygymenv[:action_space][:sample]())
    t = @elapsed while !done
        i += 1
        pycall!(temp, pygymenv[:step], PyVector, pycall(pygymenv[:action_space][:sample], PyObject))
        state, reward, done, info = temp
    end
    t / i
end

ts = [act3() for _ in 1:1000]
mean(ts), std(ts)
# (0.00034551712150827387, 0.00032118476732335505)

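Surprised? There is no speedup at all. On closer inspection, pycall! is faster because it cuts down on object construction around the call (the result is written into the preallocated temp instead of a freshly allocated PyObject). In the code above, however, the line state, reward, done, info = temp explicitly builds four new local variables from the result again, so that overhead can be removed too. Since the loop condition only needs done, it can be read straight from temp: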

Version 4

function act4()
    temp = PyNULL()
    pygymenv = gym.make("CartPole-v0")
    pygymenv[:reset]()
    i = 0
    pycall!(temp, pygymenv[:step], PyVector, pycall(pygymenv[:action_space][:sample], PyObject))
    t = @elapsed while !temp[3]
        i += 1
        pycall!(temp, pygymenv[:step], PyVector, pycall(pygymenv[:action_space][:sample], PyObject))
    end
    t / i
end

ts = [act4() for _ in 1:1000]
mean(ts), std(ts)
# (4.187357851583132e-5, 3.8224473318217355e-5)

Well, now it's only 2 to 3 times slower than the Python code, which is good enough for my purposes...
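For a side-by-side view of all the runs, the sketch below converts the reported means to microseconds per step and divides each Julia version by the Python baseline. The numbers are the means copied from the outputs in this post. Nothing is re-measured here.

```python
# Mean seconds per env.step() as reported in this post.
python_mean = 1.6214491620061298e-05
julia_means = {
    "Version 1": 0.0003923129730181138,
    "Version 2": 0.0003203265628935604,
    "Version 3": 0.00034551712150827387,
    "Version 4": 4.187357851583132e-5,
}


def to_us(seconds: float) -> float:
    """Convert seconds to microseconds."""
    return seconds * 1e6


print(f"Python   : {to_us(python_mean):7.1f} us/step")
for name, mean in julia_means.items():
    # Ratio > 1 means the Julia/PyCall version is slower than plain Python.
    print(f"{name}: {to_us(mean):7.1f} us/step, {mean / python_mean:5.1f}x the Python time")
```

Versions 1 to 3 all land at roughly 20 to 24 times the Python time. Only Version 4, which skips the destructuring step, gets down to about 2.6x.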