XLA Interface
=============

To improve the efficiency of the overall system, envpool provides an XLA API.
With this API, the environment step and the agent step can be just-in-time
compiled together when the agent is implemented in JAX or TensorFlow.

The full example is at https://github.com/sail-sg/envpool/blob/main/examples/xla_step.py

Stateless functions
-------------------

The main obstacle to jitting the environment is that ``env.step(action) -> state``
(and likewise ``recv``/``send``) is not a pure function: it mutates the state
of the underlying ``env``.

To overcome this, we introduce pure functional versions of ``step``, ``recv``
and ``send``. Each takes the envpool handle as an explicit argument and returns
a new handle alongside its result. The XLA versions of ``step``/``recv``/``send``
have the following signatures:

::

    step(envpool_handle: Handle, action: Action) -> Tuple[Handle, State]
    recv(envpool_handle: Handle) -> Tuple[Handle, State]
    send(envpool_handle: Handle, action: Action) -> Handle

These functions can be obtained from the envpool instance created through the Python API.

::

    env = envpool.make(..., env_type="gym" | "dm" | "gymnasium")
    handle, recv, send, step = env.xla()

Example of Actor Loop
---------------------

We can now write the actor loop as follows, where ``policy`` is the
user-supplied agent function:

::

    from jax import jit, lax

    def actor_step(iter, loop_var):
        handle0, states = loop_var
        action = policy(states)
        # for gym < 0.26
        handle1, (new_states, rew, done, info) = step(handle0, action)
        # for gym >= 0.26
        # handle1, (new_states, rew, term, trunc, info) = step(handle0, action)
        # for dm
        # handle1, new_states = step(handle0, action)
        return (handle1, new_states)

    @jit
    def run_actor_loop(num_steps, init_var):
        return lax.fori_loop(0, num_steps, actor_step, init_var)

    states = env.reset()
    run_actor_loop(100, (handle, states))

Or, with the asynchronous API:

::

    def actor_step(iter, handle):
        handle0 = handle
        handle1, states = recv(handle0)
        action = policy(states.observation.obs)
        # thread the handle returned by recv into send
        handle2 = send(handle1, action, states.observation.env_id)
        return handle2

    @jit
    def run_actor_loop(num_steps):
        return lax.fori_loop(0, num_steps, actor_step, handle)

    env.async_reset()
    run_actor_loop(100)

It is also possible to overlap ``send`` and ``recv``:

::

    def actor_step(iter, loop_var):
        handle0, states = loop_var
        action = policy(states.observation.obs)
        handle1 = send(handle0, action, states.observation.env_id)
        # recv deliberately takes handle0, not the handle returned by send
        handle1, new_states = recv(handle0)
        return handle1, new_states

    @jit
    def run_actor_loop(num_steps, init_var):
        return lax.fori_loop(0, num_steps, actor_step, init_var)

    env.async_reset()
    handle, states = recv(handle)
    run_actor_loop(100, (handle, states))

In the above case, ``recv`` takes ``handle0`` rather than the handle returned
by ``send``, so it has no data dependency on ``policy`` or ``send``. This means
``policy`` and ``recv`` will be overlapped in each iteration.
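The handle-threading pattern used in the loops above can be tried without envpool installed. The following is a minimal sketch, not envpool's implementation: ``toy_step`` is a hypothetical stand-in for the ``step`` function returned by ``env.xla()``, its "handle" is just an integer counter, and ``policy`` is a placeholder. It only illustrates how a pure ``step(handle, action) -> (handle, state)`` function can be carried through ``lax.fori_loop`` under ``jit``.

```python
import jax.numpy as jnp
from jax import jit, lax


def toy_step(handle, action):
    """Hypothetical pure step: returns a new handle and a new observation."""
    new_handle = handle + 1   # stand-in for the updated envpool handle
    new_obs = action + 1.0    # stand-in for the environment transition
    return new_handle, new_obs


def policy(obs):
    """Placeholder policy (an assumption for this sketch): doubles the observation."""
    return obs * 2.0


def actor_step(i, loop_var):
    # Unpack the loop carry, act, step, and return a carry of the same structure.
    handle0, obs = loop_var
    action = policy(obs)
    handle1, new_obs = toy_step(handle0, action)
    return handle1, new_obs


@jit
def run_actor_loop(num_steps, init_var):
    return lax.fori_loop(0, num_steps, actor_step, init_var)


init = (jnp.asarray(0, dtype=jnp.int32), jnp.zeros(4, dtype=jnp.float32))
final_handle, final_obs = run_actor_loop(3, init)
```

Because nothing is mutated in place, the whole loop is a pure function of ``init``: after three iterations the toy handle is ``3`` and every entry of the observation is ``7.0`` (0 → 1 → 3 → 7).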