Add New Environment into EnvPool
To add a new environment in C++ that EnvPool can run in parallel, we provide a developer interface in envpool/core/env.h. The generated reference for the core headers used below is available in the C++ API Reference.
For a quick and annotated example, please refer to envpool/dummy/.
envpool/atari serves as a more complex, real example.
In the following example, we will create an environment CartPole.
It follows the same CartPole dynamics as Gymnasium.
The full implementation is in Pull Request 25. Let’s go through the details step by step!
Setup File Structure
The first thing is to fork the project and add the new environment in the
envpool folder, i.e., create a classic_control folder under envpool/:
cd envpool
mkdir -p classic_control
Here is the typical file structure:
$ tree classic_control
classic_control
├── BUILD
├── cartpole.h
├── classic_control.cc
├── classic_control_test.py
├── __init__.py
└── registration.py
The files serve the following purposes:
__init__.py: to make this directory a Python package;
BUILD: to indicate the file dependency (because we use Bazel to manage this project);
cartpole.h: the CartPole environment;
classic_control.cc: pack classic_control_envpool.so via pybind11;
classic_control_test.py: a simple unit test to check that the implementation is correct;
registration.py: register CartPole-v0 and CartPole-v1 so that we can use envpool.make("CartPole-v0") to create an environment.
Implement CartPole Environment in cartpole.h
First, include the core header files:
#include "envpool/core/async_envpool.h"
#include "envpool/core/env.h"
CartPoleEnvSpec
Next, we need to create CartPoleEnvSpec to define the env-specific config,
state space, and action space. Start with a class CartPoleEnvFns:
// env-specific definition of config and state/action spec
class CartPoleEnvFns {
 public:
  static decltype(auto) DefaultConfig() {
    return MakeDict("reward_threshold"_.Bind(195.0));
  }
  template <typename Config>
  static decltype(auto) StateSpec(const Config& conf) {
    float fmax = std::numeric_limits<float>::max();
    return MakeDict("obs"_.Bind(
        Spec<float>({4}, {{-4.8, -fmax, -M_PI / 7.5, -fmax},
                          {4.8, fmax, M_PI / 7.5, fmax}})));
  }
  template <typename Config>
  static decltype(auto) ActionSpec(const Config& conf) {
    // the last argument in Spec is for the range definition
    return MakeDict("action"_.Bind(Spec<int>({-1}, {0, 1})));
  }
};
// this line will concat common config and common state/action spec
using CartPoleEnvSpec = EnvSpec<CartPoleEnvFns>;
DefaultConfig: the default config to create the CartPole environment;
StateSpec: the state space (including observation and info) definition;
ActionSpec: the action space definition.
CartPole is quite a simple environment. The observation is a NumPy array with
shape (4,), and the action is discrete, taking a value in [0, 1]. This
definition is also visible from the Python side:
>>> import envpool
>>> spec = envpool.make_spec("CartPole-v0")
>>> spec
CartPoleEnvSpec(num_envs=1, batch_size=1, num_threads=0, max_num_players=1, thread_affinity_offset=-1, base_path='envpool', seed=42, gym_reset_return_info=True, max_episode_steps=200, reward_threshold=195.0)
>>> # if we change a config value
>>> env = envpool.make_gym("CartPole-v0", reward_threshold=666)
>>> env
CartPoleGymnasiumEnvPool(num_envs=1, batch_size=1, num_threads=0, max_num_players=1, thread_affinity_offset=-1, base_path='envpool', seed=42, gym_reset_return_info=True, max_episode_steps=200, reward_threshold=666.0)
>>> # observation space and action space
>>> env.observation_space
Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32)
>>> env.action_space
Discrete(2)
>>> env.spec.reward_threshold
666.0
Danger
When using a string in MakeDict, you should explicitly use
std::string. For example,
auto config = MakeDict("path"_.Bind("init_path"));
Here the type of “path” is const char * instead of
std::string, which can leave config["path"_] holding a
meaningless string when it is used later. Instead, write the code as
auto config = MakeDict("path"_.Bind(std::string("init_path")));
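The type deduction behind this warning can be reproduced without EnvPool. The following is a minimal sketch, assuming the value type is deduced from the argument: Entry and this Bind function are hypothetical stand-ins written for illustration, not EnvPool's real implementation.

```cpp
#include <string>
#include <type_traits>
#include <utility>

// Hypothetical stand-in for a config entry: it stores whatever type the
// argument decays to (an assumption about how the value type is deduced).
template <typename T>
struct Entry {
  T value;
};

template <typename T>
Entry<std::decay_t<T>> Bind(T&& v) {
  return {std::forward<T>(v)};
}

// A string literal decays to a non-owning const char*, not std::string.
static_assert(std::is_same_v<decltype(Bind("init_path").value), const char*>);
// Wrapping the literal makes the stored type an owning std::string.
static_assert(std::is_same_v<decltype(Bind(std::string("init_path")).value),
                             std::string>);
```

A const char* only points at characters owned elsewhere, so it gives no lifetime guarantee once the value is copied around, whereas std::string owns its buffer.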
Note
The above example shows how to define a discrete action space by specifying
the last argument of Spec. As another example, for an environment with
6 actions ranging from 0 to 5:
template <typename Config>
static decltype(auto) ActionSpec(const Config& conf) {
  return MakeDict("action"_.Bind(Spec<int>({-1}, {0, 5})));
  // or remove -1, no difference in single-player env
  // return MakeDict("action"_.Bind(Spec<int>({}, {0, 5})));
}
For continuous action space, change the type of Spec to float or
double. For example, if the action is a NumPy array with four floats,
ranging from -2 to 2:
template <typename Config>
static decltype(auto) ActionSpec(const Config& conf) {
  return MakeDict("action"_.Bind(Spec<float>({-1, 4}, {-2.0, 2.0})));
  // or remove -1, no difference in single-player env
  // return MakeDict("action"_.Bind(Spec<float>({4}, {-2.0, 2.0})));
}
Note
-1 in Spec is reserved for the number of players. In a single-player
environment, Spec<int>({-1}) is the same as Spec<int>({}) (empty
shape). In a multi-player environment, an empty-shape spec is only a
single int value per environment, while the former is an array with
length == #players (which can be 0 when all players are dead).
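To make the role of -1 concrete, here is a minimal sketch of how a spec shape could be turned into a concrete shape once the number of players is known. ResolveShape is a hypothetical helper written for this illustration; it is not part of EnvPool.

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper: replace a leading -1 in a spec shape with the
// current number of players. Shapes without a leading -1 are unchanged.
std::vector<int> ResolveShape(std::vector<int> shape, int player_num) {
  assert(player_num >= 0);
  if (!shape.empty() && shape[0] == -1) {
    shape[0] = player_num;
  }
  return shape;
}
```

Under this sketch, ResolveShape({-1, 2}, 10) yields {10, 2}, ResolveShape({-1}, 0) yields {0} (all players dead), and the empty shape stays empty, i.e. a single value per environment.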
Note
The common config and common state/action spec are defined in env_spec.h.
Note
EnvPool supports environments that have multiple observations or even
nested observations. For example, FetchReach-v4:
>>> import envpool
>>> env = envpool.make_gymnasium("FetchReach-v4", num_envs=1, seed=0)
>>> env.observation_space
Dict(achieved_goal:Box([-inf ...], [inf ...], (3,), float32), desired_goal:Box([-inf ...], [inf ...], (3,), float32), observation:Box([-inf ...], [inf ...], (10,), float32))
>>> obs, info = env.reset()
>>> obs["observation"].shape
(1, 10)
>>> obs["achieved_goal"].shape
(1, 3)
If we want to create such a state spec (including both obs and info), here is the solution:
template <typename Config>
static decltype(auto) StateSpec(const Config& conf) {
  return MakeDict(
      "obs:observation"_.Bind(Spec<float>({10})),
      "obs:achieved_goal"_.Bind(Spec<float>({3})),
      "obs:desired_goal"_.Bind(Spec<float>({3})),
      "info:is_success"_.Bind(Spec<float>({})));
}
Keys that start with obs: are parsed into the obs dict, and similarly
keys that start with info: are parsed into the info dict.
For nested observations such as {"obs_a": {"obs_b": 6}}, use . to
indicate the hierarchy:
return MakeDict("obs:obs_a.obs_b"_.Bind(Spec<int>({})));
ActionSpec works the same way. The only difference is that action keys take
no obs: or info: prefix.
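The key convention above can be sketched as a small parser. This is only an illustration of the naming scheme, assuming a key has at most one ":" prefix and uses "." between nesting levels; ParsedKey and ParseKey are hypothetical names, and EnvPool's own parsing code may differ.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical result type: the top-level dict a key belongs to and the
// path of nested keys inside that dict.
struct ParsedKey {
  std::string group;              // "obs", "info", or "" when no prefix
  std::vector<std::string> path;  // e.g. {"obs_a", "obs_b"}
};

ParsedKey ParseKey(const std::string& key) {
  assert(!key.empty());
  ParsedKey out;
  std::string rest = key;
  const std::size_t colon = key.find(':');
  if (colon != std::string::npos) {
    out.group = key.substr(0, colon);
    rest = key.substr(colon + 1);
  }
  // Split the remainder on '.' to recover the nesting hierarchy.
  std::size_t start = 0;
  while (true) {
    const std::size_t dot = rest.find('.', start);
    out.path.push_back(rest.substr(start, dot - start));
    if (dot == std::string::npos) break;
    start = dot + 1;
  }
  return out;
}
```

With this sketch, "obs:obs_a.obs_b" maps to group "obs" with path {"obs_a", "obs_b"}, while an action key such as "action" has an empty group.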
Note
In dm_env, keys in Spec that start with either obs: or info: will
be merged under timestep.observation.
Note
To create a dynamic-shape array (which will be converted into a NumPy
array with object dtype), you can use Spec<Container<...>>, e.g.:
"info:id_list"_.Bind(Spec<Container<int>>({-1})),
CartPoleEnv
Now we are going to create a class CartPoleEnv that inherits
Env.
For convenience, the Env class already defines three types, Spec, State and Action, which follow the definition of CartPoleEnvSpec.
You must override the following functions:
the constructor, in this case CartPoleEnv(const Spec& spec, int env_id); you can use spec.config["max_episode_steps"_] to extract a value from the config;
bool IsDone(): return a boolean that indicates whether the current episode has finished;
void Reset(): perform one env.reset();
void Step(const Action& action): perform one env.step(action).
The reference implementation is in envpool/classic_control/cartpole.h.
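As a rough standalone illustration of what these three methods do, the sketch below mirrors the IsDone/Reset/Step interface without inheriting from Env, so it compiles without the EnvPool headers. It is not the code in cartpole.h: CartPoleSketch and its accessors are hypothetical, the constants follow the classic Gymnasium CartPole dynamics, the random initial state is replaced by zeros, and Step takes a plain int instead of an Action.

```cpp
#include <cassert>
#include <cmath>

class CartPoleSketch {
 public:
  explicit CartPoleSketch(int max_episode_steps)
      : max_episode_steps_(max_episode_steps) {
    assert(max_episode_steps > 0);
  }

  bool IsDone() const { return done_; }

  // Start from the origin (assumption: no random initial state here).
  void Reset() {
    x_ = x_dot_ = theta_ = theta_dot_ = 0.0;
    elapsed_step_ = 0;
    done_ = false;
  }

  // action: 0 pushes the cart to the left, 1 pushes it to the right.
  void Step(int action) {
    const double force = action == 1 ? kForceMag : -kForceMag;
    const double cos_t = std::cos(theta_);
    const double sin_t = std::sin(theta_);
    const double temp =
        (force + kPoleMassLength * theta_dot_ * theta_dot_ * sin_t) /
        kTotalMass;
    const double theta_acc =
        (kGravity * sin_t - cos_t * temp) /
        (kLength * (4.0 / 3.0 - kMassPole * cos_t * cos_t / kTotalMass));
    const double x_acc =
        temp - kPoleMassLength * theta_acc * cos_t / kTotalMass;
    // Explicit Euler integration with time step kTau.
    x_ += kTau * x_dot_;
    x_dot_ += kTau * x_acc;
    theta_ += kTau * theta_dot_;
    theta_dot_ += kTau * theta_acc;
    ++elapsed_step_;
    done_ = std::abs(x_) > kXThreshold ||
            std::abs(theta_) > kThetaThreshold ||
            elapsed_step_ >= max_episode_steps_;
  }

  double x() const { return x_; }
  double theta() const { return theta_; }

 private:
  static constexpr double kPi = 3.14159265358979323846;
  static constexpr double kGravity = 9.8;
  static constexpr double kMassCart = 1.0;
  static constexpr double kMassPole = 0.1;
  static constexpr double kTotalMass = kMassCart + kMassPole;
  static constexpr double kLength = 0.5;  // half of the pole length
  static constexpr double kPoleMassLength = kMassPole * kLength;
  static constexpr double kForceMag = 10.0;
  static constexpr double kTau = 0.02;  // seconds between state updates
  static constexpr double kXThreshold = 2.4;
  static constexpr double kThetaThreshold = 12.0 * 2.0 * kPi / 360.0;

  int max_episode_steps_;
  int elapsed_step_ = 0;
  bool done_ = true;  // an episode must be Reset() before stepping
  double x_ = 0.0, x_dot_ = 0.0, theta_ = 0.0, theta_dot_ = 0.0;
};
```

In the real environment the constructor receives the spec and the env id, and the state is written through Allocate rather than exposed through accessors.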
Array Read/Write
State and Action are dict-style data structures for easier prototyping.
All values in these dictionaries are of type Array, which mimics the
functionality of a multi-dimensional array.
To extract a value from action in CartPoleEnv:
// auto convert the first element in action["action"_]
int act = action["action"_];
// for continuous action space, e.g.
// float act2 = action["action"_][2];
If the state/action contains several keys and each element is a multi-dimensional array, e.g., an image, there are three ways to deal with array read/write:
uint8_t* ptr = static_cast<uint8_t*>(state["obs"_].data());
for (int i = 0; i < 4; ++i) {
  for (int j = 0; j < 84; ++j) {
    for (int k = 0; k < 84; ++k) {
      // 1. use []
      state["obs"_][i][j][k] = ...;
      // 2. use (), faster than 1
      state["obs"_](i, j, k) = ...;
      // 3. use raw pointer
      ptr[i * 84 * 84 + j * 84 + k] = ...;
    }
  }
}
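The raw-pointer variant relies on row-major layout. As a small standalone check of that arithmetic, the sketch below computes the offset of element (i, j, k) for an array whose two trailing dimensions are d1 and d2; FlatIndex is a hypothetical helper, not an EnvPool function.

```cpp
#include <cstddef>

// Row-major offset of element (i, j, k) in an array of shape (d0, d1, d2).
// The leading dimension d0 does not appear in the formula.
constexpr std::size_t FlatIndex(std::size_t i, std::size_t j, std::size_t k,
                                std::size_t d1, std::size_t d2) {
  return (i * d1 + j) * d2 + k;
}

// For a (4, 84, 84) frame stack this matches i * 84 * 84 + j * 84 + k.
static_assert(FlatIndex(0, 0, 0, 84, 84) == 0);
static_assert(FlatIndex(1, 2, 3, 84, 84) == 1 * 84 * 84 + 2 * 84 + 3);
static_assert(FlatIndex(3, 83, 83, 84, 84) == 4 * 84 * 84 - 1);
```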
If one of the arrays in the state is dynamic-shaped:
Container<int>& dyn = state["obs:dyn"_][i];
// new spec
auto dyn_spec = ::Spec<int>({env_id_ + 1, spec_.config["state_num"_]});
// use this spec to create an array
auto* array = new TArray<int>(dyn_spec);
// perform some normal array writing
// finally pass it to dynamic array
dyn.reset(array);
Allocate State in Reset and Step
EnvPool carefully designs its data movement to achieve zero-copy with the lowest overhead, and provides a simple API to keep this user-friendly.
At the end of the Reset and Step functions, you need to call the Allocate
method to allocate the state for writing. For example, in CartPoleEnv:
State state = Allocate();
state["obs"_][0] = static_cast<float>(x_);
state["obs"_][1] = static_cast<float>(x_dot_);
state["obs"_][2] = static_cast<float>(theta_);
state["obs"_][3] = static_cast<float>(theta_dot_);
state["reward"_] = 1.0f;
// here is a buggy usage because x_ is float64 and state["obs"_] is float32
// state["obs"_][0] = x_;
You do not need to pass this state to any other function or return it. Instead, AsyncEnvPool automatically processes the data and packs it for the Python interface.
Note
For multi-player environments, you need to allocate state with an extra
argument player_num. For example, if the state spec is:
template <typename Config>
static decltype(auto) StateSpec(const Config& conf) {
  return MakeDict(
      "obs:players.obs"_.Bind(Spec<uint8_t>({-1, 4, 84, 84})),
      "obs:players.location"_.Bind(Spec<uint8_t>({-1, 2})),
      "info:players.health"_.Bind(Spec<int>({-1})),
      "info:player_num"_.Bind(Spec<int>({})),
      "info:bla"_.Bind(Spec<float>({2, 3, 3})),
      "info:list"_.Bind(Spec<Container<float>>({-1}))
  );
}
After calling auto state = Allocate(10), the state looks like this:
state["obs:players.obs"_];       // shape: (10, 4, 84, 84)
state["obs:players.location"_];  // shape: (10, 2)
state["info:players.health"_];   // shape: (10,)
state["info:player_num"_];       // shape: (), only one element
state["info:bla"_];              // shape: (2, 3, 3)
state["info:list"_];             // shape: (10,) with dtype=object
Danger
Please make sure the types are correct. Assigning an int to a float array or
assigning a double to a uint64_t array will not generate any compilation
error, but the data will be wrong at runtime. Please use
static_cast to convert the type correctly.
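One plausible way such a silent mismatch arises is a type-erased buffer that copies the bytes of whatever type the caller passes in. The sketch below demonstrates that failure mode under this assumption; RawFloatArray is a hypothetical class written for illustration and is not EnvPool's Array.

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>
#include <vector>

// Hypothetical type-erased buffer of float elements stored as raw bytes.
class RawFloatArray {
 public:
  explicit RawFloatArray(std::size_t n) : bytes_(n * sizeof(float)) {}

  // Copies the bytes of `value` as-is. T is deduced from the argument and
  // is never checked against the element type (float).
  template <typename T>
  void Write(std::size_t index, const T& value) {
    assert(index * sizeof(float) + sizeof(T) <= bytes_.size());
    std::memcpy(bytes_.data() + index * sizeof(float), &value, sizeof(T));
  }

  float Read(std::size_t index) const {
    float out;
    std::memcpy(&out, bytes_.data() + index * sizeof(float), sizeof(float));
    return out;
  }

 private:
  std::vector<unsigned char> bytes_;
};
```

With this buffer, Write(0, static_cast<float>(x)) stores a valid float, whereas Write(0, x) with a double x copies eight bytes of a double over two float slots and compiles without any warning.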
CartPoleEnvPool
After creating CartPoleEnv, one more line gives us
CartPoleEnvPool:
using CartPoleEnvPool = AsyncEnvPool<CartPoleEnv>;
Miscellaneous
Note
Please do not generate pseudo-random numbers with rand() % MAX. Instead,
use random number distributions to generate
thread-safe deterministic pseudo-random numbers. A std::mt19937 generator
has already been defined as gen_ (link).
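As a minimal sketch of this pattern, the helper below draws values from a std::uniform_real_distribution using a generator passed by reference, the way an environment would use its own gen_. SampleInitState and the [-0.05, 0.05) range are chosen for illustration only.

```cpp
#include <cassert>
#include <cstddef>
#include <random>
#include <vector>

// Hypothetical helper: draw n values in [-0.05, 0.05) from a generator
// that is owned by a single environment instance.
std::vector<double> SampleInitState(std::mt19937& gen, int n) {
  assert(n >= 0);
  std::uniform_real_distribution<double> dist(-0.05, 0.05);
  std::vector<double> out;
  out.reserve(static_cast<std::size_t>(n));
  for (int i = 0; i < n; ++i) {
    out.push_back(dist(gen));
  }
  return out;
}
```

Because each generator is seeded explicitly and not shared, two generators built from the same seed replay the same sequence, and no hidden global state is touched from multiple threads as it would be with rand().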
Note
ENVPOOL_TEST is a test-time macro. To make a piece of C++ code available
only during unit tests:
#ifdef ENVPOOL_TEST
fprintf(stderr, "here\n");
LOG(INFO) << "another error log print method.";
#endif
Generate Dynamic Linked .so File and Instantiate in Python
We use pybind11 to expose this C++ code to the Python interface. We have already wrapped this interface, so you only need to add a few lines to make it work:
#include "envpool/classic_control/cartpole.h"
#include "envpool/core/py_envpool.h"
// generate python-side (raw) CartPoleEnvSpec
using CartPoleEnvSpec = PyEnvSpec<classic_control::CartPoleEnvSpec>;
// generate python-side (raw) CartPoleEnvPool
using CartPoleEnvPool = PyEnvPool<classic_control::CartPoleEnvPool>;
// generate classic_control_envpool.so
PYBIND11_MODULE(classic_control_envpool, m) {
REGISTER(m, CartPoleEnvSpec, CartPoleEnvPool)
}
After that, you can import _CartPoleEnvSpec and _CartPoleEnvPool from
classic_control_envpool.so.
The next step is to apply the Python-side wrappers (gymnasium/dm_env APIs) to the raw classes.
In envpool/classic_control/__init__.py, use the py_env function to
instantiate CartPoleEnvSpec, CartPoleDMEnvPool,
and CartPoleGymnasiumEnvPool.
from envpool.python.api import py_env

from .classic_control_envpool import _CartPoleEnvPool, _CartPoleEnvSpec

(
    CartPoleEnvSpec,
    CartPoleDMEnvPool,
    CartPoleGymnasiumEnvPool,
) = py_env(_CartPoleEnvSpec, _CartPoleEnvPool)

__all__ = [
    "CartPoleEnvSpec",
    "CartPoleDMEnvPool",
    "CartPoleGymnasiumEnvPool",
]
Register CartPole-v0/1 in EnvPool
To register a task in EnvPool, you need to call the register function in
envpool.registration. Here is registration.py:
from envpool.registration import register

register(
    task_id="CartPole-v0",
    import_path="envpool.classic_control",
    spec_cls="CartPoleEnvSpec",
    dm_cls="CartPoleDMEnvPool",
    gymnasium_cls="CartPoleGymnasiumEnvPool",
    max_episode_steps=200,
    reward_threshold=195.0,
)

register(
    task_id="CartPole-v1",
    import_path="envpool.classic_control",
    spec_cls="CartPoleEnvSpec",
    dm_cls="CartPoleDMEnvPool",
    gymnasium_cls="CartPoleGymnasiumEnvPool",
    max_episode_steps=500,
    reward_threshold=475.0,
)
task_id, import_path, spec_cls, dm_cls and gymnasium_cls
are required arguments. Other arguments such as
max_episode_steps and reward_threshold are env-specific. For example,
if someone uses envpool.make("CartPole-v1"), reward_threshold will
be set to 475.0 when CartPoleEnvPool is initialized.
Finally, it is crucial to let the top-level module import this file. In
envpool/entry.py, add the following line:
import envpool.classic_control.registration # noqa: F401
Write Bazel BUILD File
Bazel is a powerful tool for building and testing C++-based projects, and Python projects can use it as well. Bazel manages all files in EnvPool.
There are some tutorials for Bazel, but for convenience, we only demonstrate the key point here when using Bazel in this project, i.e., how to write BUILD correctly.
Bazel Header
Most of the time, directly include the following things at the top of BUILD:
load("@pip_requirements//:requirements.bzl", "requirement")
load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")
package(default_visibility = ["//visibility:public"])
Types of Rules
cc_library: C++ header file *.h, usually for environment definition. Required fields: name, hdrs;
cc_test: C++ source file *.cc for running C++ unit tests. Required fields: name, srcs;
pybind_extension: C++ source file *.cc to generate a .so file named {name}.so. Required fields: name, srcs;
py_library: Python library file *.py. Required fields: name, srcs;
py_test: Python file *.py for running Python unit tests. Required fields: name, srcs.
All of the above declarations can have deps and data fields, which
explicitly specify dependencies on other Bazel BUILD rules or on
third-party data. We will explain deps in the next section.
If you are looking for other functionality such as genrule, please refer
to Third-party Dependencies.
deps
Let’s first take a look at the BUILD file in classic_control:
load("@pip_requirements//:requirements.bzl", "requirement")
load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")
package(default_visibility = ["//visibility:public"])
cc_library(
name = "cartpole",
hdrs = ["cartpole.h"],
deps = [
"//envpool/core:async_envpool",
],
)
pybind_extension(
name = "classic_control_envpool",
srcs = [
"classic_control.cc",
],
deps = [
":cartpole",
"//envpool/core:py_envpool",
],
)
py_library(
name = "classic_control",
srcs = ["__init__.py"],
data = [":classic_control_envpool.so"],
deps = ["//envpool/python:api"],
)
py_library(
name = "classic_control_registration",
srcs = ["registration.py"],
deps = [
"//envpool:registration",
],
)
py_test(
name = "classic_control_test",
srcs = ["classic_control_test.py"],
deps = [
":classic_control",
":classic_control_registration",
requirement("numpy"),
requirement("absl-py"),
],
)
There are several ways to declare a dependency:
relative path: :cartpole points to the first item (the cartpole cc_library);
absolute path: //envpool/core:async_envpool points to async_envpool under envpool/core;
Python dependency: requirement("numpy") means this target uses NumPy as a runtime dependency;
third-party dependency (not shown above): explained in the next section.
And don’t forget to modify the top-level Bazel BUILD dependency:
py_library(
name = "entry",
srcs = ["entry.py"],
deps = [
"//envpool/atari:atari_registration",
+ "//envpool/classic_control:classic_control_registration",
],
)
py_library(
name = "envpool",
srcs = ["__init__.py"],
deps = [
":entry",
":registration",
"//envpool/atari",
+ "//envpool/classic_control",
"//envpool/python",
],
)
Also, check that the .so file is packed into the .whl
successfully. In setup.cfg:
[options.package_data]
envpool = atari/*.so
atari/roms/*.bin
+ classic_control/*.so
Now you can run envpool.make("CartPole-v0") by re-installing EnvPool:
# generate .whl file
make bazel-build
# install .whl
pip install dist/envpool-<version>-*.whl
Testing
To test whether the BUILD file is correct for Bazel to compile:
bazel build //envpool/classic_control --config=debug
This command displays the details of the compilation, which makes build errors easier to track down.
Third-party Dependencies
The CartPole environment is so simple that it has no third-party dependencies. However, more complex environments often need some.
For a third-party Python dependency, for instance to add tianshou
as a test dependency, edit third_party/pip_requirements/requirements.txt:
six
tensorboard
+tianshou
torch
tqdm
If we want to add tianshou as a build dependency, in setup.cfg:
[options]
packages = find:
python_requires = >=3.11
install_requires =
dm-env>=1.4
gymnasium>=0.26
numpy>=1.19
types-protobuf>=3.17.3
typing-extensions
+ tianshou
As for source-code dependency, for example, if we want to download
ThreadPool and use it in
//envpool/core:async_envpool, here are the steps to follow:
add download item for ThreadPool in
envpool/workspace0.bzl:
maybe(
http_archive,
name = "threadpool",
sha256 = "18854bb7ecc1fc9d7dda9c798a1ef0c81c2dd331d730c76c75f648189fa0c20f",
strip_prefix = "ThreadPool-9a42ec1329f259a5f4881a291db1dcb8f2ad9040",
urls = [
"https://github.com/progschj/ThreadPool/archive/9a42ec1329f259a5f4881a291db1dcb8f2ad9040.zip",
],
build_file = "//third_party/threadpool:threadpool.BUILD",
)
Here is the reference documentation for http_archive.
add ThreadPool into
third_party/:
mkdir -p third_party/threadpool
touch third_party/threadpool/BUILD
touch third_party/threadpool/threadpool.BUILD
leave BUILD empty, and add the following rules in threadpool.BUILD:
package(default_visibility = ["//visibility:public"])
cc_library(
name = "threadpool",
hdrs = ["ThreadPool.h"],
)
This rule exposes ThreadPool.h at the top level of the “threadpool” namespace.
modify Bazel build rules of async_envpool:
cc_library(
name = "async_envpool",
hdrs = ["async_envpool.h"],
deps = [
":action_buffer_queue",
":array",
":env",
":envpool",
":spec",
":state_buffer_queue",
+ "@threadpool",
],
)
The dependency string format is @<package> or @<package>//:<name>.
For genrule() and data = [...], please refer to Bazel official
documentation or
Atari BUILD example.
Add Unit Test for CartPoleEnv
It is highly encouraged to write unit tests to ensure the correctness of the new environment. You can write both Python and C++ tests.
C++ Env Tests
We use GoogleTest to run C++ unit tests. Please refer to the official GoogleTest documentation to see how to use it.
To enable GoogleTest, you need to modify the corresponding Bazel BUILD rule:
cc_test(
name = "atari_env_test",
srcs = ["atari_env_test.cc"],
deps = [
":atari_env",
+ "@com_google_googletest//:gtest_main",
],
)
Python Env Tests
We use Abseil test to run Python unit tests. To enable it, you need to modify the corresponding Bazel BUILD rule:
py_test(
name = "classic_control_test",
srcs = ["classic_control_test.py"],
deps = [
":classic_control",
requirement("numpy"),
+ requirement("absl-py"),
],
)
Make Tests
You can add a test in envpool/make_test.py to see if the environment can be
successfully created.
New Environment Review Checklist
Before opening a PR for a new environment family, make sure the implementation is complete end-to-end:
The runtime implementation is native C++; do not call or embed the official Python environment from C++ runtime code.
Pin the exact upstream oracle version when one exists, and keep all tests anchored to that version.
Register every intended upstream task ID or scenario. Do not collapse multiple upstream IDs into one generic EnvPool task.
Add registry coverage that checks EnvPool task IDs and task configuration values against the pinned upstream source when practical.
Add deterministic tests that replay the same external action sequence across reset plus nontrivial multi-step rollouts for every registered ID. If render is supported, include rendered frames in the determinism check.
Add step-level oracle alignment tests when an official implementation exists. A reset-time state sync is acceptable when needed, but do not sync state again during the rollout. Compare observations, rewards, terminated/truncated semantics, exposed info, and renders when render output is expected to match.
Add render tests for reset frames, multi-step frames, batched rendering, and env-id selection when rendering is supported.
Add the new family to envpool/make_test.py so source builds and installed release wheels exercise envpool.make_* for the registered import path.
Update the environment docs, docs index, README support list, and release packaging. If an official renderer exists, add an EnvPool-vs-official render comparison image to the environment doc page.
Keep tolerances narrow, platform-scoped, and documented. Do not skip new environment tests on a supported platform just to make CI pass.