XLA Interface¶
To boost the efficiency of the overall system, we introduce the XLA API for EnvPool. With this API, we can just-in-time compile the environment step and the agent step together, as long as the agent part is implemented in JAX or TensorFlow.
The full example is at https://github.com/sail-sg/envpool/blob/main/examples/xla_step.py
Stateless functions¶
The main issue with jitting the environment is that the env.step(action) -> state function (and similarly recv/send) is not pure, i.e., it mutates the state of the underlying env.
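As a sketch of why this matters: jit traces a function once and assumes its output is fully determined by its inputs, which a side-effecting step violates.

# stateful API: two identical calls return different states, because
# the underlying env advances as a hidden side effect
state0 = env.step(action)
state1 = env.step(action)  # same inputs, different result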
To overcome this issue, we introduce a pure functional version of step (and of recv/send). Namely, the XLA version of step/recv/send has the following signature:
step(envpool_handle: Handle, action: Action) -> Tuple[Handle, State]
recv(envpool_handle: Handle) -> Tuple[Handle, State]
send(envpool_handle: Handle, action: Action, env_id: Optional[Array] = None) -> Handle
These functions can be obtained from the envpool instance created through the Python API:
env = envpool.make(..., env_type="gym" | "dm" | "gymnasium")
handle, recv, send, step = env.xla()
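For example (the environment name and size here are just an illustration):

import envpool

env = envpool.make("Pong-v5", env_type="dm", num_envs=8)
handle, recv, send, step = env.xla()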
Example of Actor Loop¶
We can now write the actor loop as:
from jax import jit, lax

def actor_step(iter, loop_var):
    handle0, states = loop_var
    action = policy(states)
    # for gym < 0.26
    handle1, (new_states, rew, done, info) = step(handle0, action)
    # for gym >= 0.26
    # handle1, (new_states, rew, term, trunc, info) = step(handle0, action)
    # for dm
    # handle1, new_states = step(handle0, action)
    return (handle1, new_states)

@jit
def run_actor_loop(num_steps, init_var):
    return lax.fori_loop(0, num_steps, actor_step, init_var)

states = env.reset()
run_actor_loop(100, (handle, states))
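The loop assumes a policy function mapping a batch of observations to a batch of actions. A trivial stand-in for testing (purely illustrative, not part of the API) could be:

import jax.numpy as jnp

def policy(obs):
    # always take action 0 in every env; a real agent would run a
    # neural network here instead
    return jnp.zeros(obs.shape[0], dtype=jnp.int32)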
Or, with the asynchronous API:
def actor_step(iter, handle):
    handle0 = handle
    handle1, states = recv(handle0)
    action = policy(states.observation.obs)
    # send takes the handle returned by recv, so recv, policy, and send
    # are sequenced within each iteration
    handle2 = send(handle1, action, states.observation.env_id)
    return handle2

@jit
def run_actor_loop(num_steps):
    return lax.fori_loop(0, num_steps, actor_step, handle)

env.async_reset()
run_actor_loop(100)
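In asynchronous mode, the total number of environments can exceed the batch size, so recv returns as soon as any batch of envs finishes stepping. A setup sketch (sizes illustrative):

import envpool

# 16 envs in total; each recv returns the first 8 that are ready
env = envpool.make("Pong-v5", env_type="dm", num_envs=16, batch_size=8)
handle, recv, send, step = env.xla()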
It is also possible to overlap send and recv:
def actor_step(iter, loop_var):
    handle0, states = loop_var
    action = policy(states.observation.obs)
    handle1 = send(handle0, action, states.observation.env_id)
    # recv uses handle0 rather than handle1, so it does not depend on the
    # result of send and XLA may schedule the two concurrently
    handle1, new_states = recv(handle0)
    return (handle1, new_states)

@jit
def run_actor_loop(num_steps, init_var):
    return lax.fori_loop(0, num_steps, actor_step, init_var)

env.async_reset()
handle, states = recv(handle)
run_actor_loop(100, (handle, states))
In the above case, recv uses handle0, which means policy and recv will be overlapped in each iteration.