[codex] add TorchRL Flame DQN example #453
Conversation
Code Review
This pull request introduces a new reinforcement learning example, examples/rl/torchrl_dqn, which adapts the TorchRL CartPole DQN tutorial to the Flame Runner framework. The implementation includes a distributed rollout collector, a sharded replay buffer using Flame's ObjectRef and patch_object capabilities, and a local training mode for validation. Feedback on the code changes highlights opportunities to improve performance by moving imports out of the hot collection loop and suggests removing an unused helper function in the main entry point.
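The sharded replay buffer is the most Flame-specific piece of the example. Flame's `ObjectRef` and `patch_object` APIs are not reproduced here, but the sharding idea can be sketched in plain Python. The class name, round-robin insert policy, and pooled uniform sampling below are illustrative assumptions, not the PR's actual implementation:

```python
import random


class ShardedReplayBuffer:
    """Minimal sketch of a sharded replay buffer: transitions are
    distributed round-robin across shards, and sampling draws
    uniformly from the pooled contents of all shards."""

    def __init__(self, num_shards: int, capacity_per_shard: int) -> None:
        self.shards: list[list] = [[] for _ in range(num_shards)]
        self.capacity = capacity_per_shard
        self._next = 0  # round-robin cursor for inserts

    def add(self, transition) -> None:
        shard = self.shards[self._next]
        if len(shard) >= self.capacity:
            shard.pop(0)  # evict the oldest entry when the shard is full
        shard.append(transition)
        self._next = (self._next + 1) % len(self.shards)

    def sample(self, batch_size: int) -> list:
        # Pool every shard's contents, then sample without replacement.
        pool = [t for shard in self.shards for t in shard]
        return random.sample(pool, min(batch_size, len(pool)))
```

In the distributed version, each shard would live behind its own `ObjectRef` and inserts would be applied remotely via `patch_object`; the local sketch keeps the same insert/sample contract so it can stand in during validation.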
```python
import random

import torch
from model import flatten_observation
from tensordict import TensorDict
```
These imports are performed inside the _select_action method, which is called for every environment step (e.g., 100 times per collection). Although Python caches imports in sys.modules, repeated lookups in a performance-sensitive reinforcement learning loop add unnecessary overhead. Moving these imports to the top of the file is more efficient and adheres to PEP 8 guidelines.
References
- Imports are always put at the top of the file, just after any module comments and docstrings, and before module globals and constants. (link)
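The overhead the review describes is small but measurable: a function-level `import` re-executes the import machinery (a `sys.modules` lookup plus a local rebinding) on every call, whereas a module-level import runs once. A minimal timing sketch (the function bodies are illustrative, not the PR's `_select_action`):

```python
import timeit


def select_action_inline():
    # Executed on every call: sys.modules lookup + name binding.
    import random
    return random.random()


import random  # module level: imported once, at load time (PEP 8)


def select_action_hoisted():
    return random.random()


inline = timeit.timeit(select_action_inline, number=100_000)
hoisted = timeit.timeit(select_action_hoisted, number=100_000)
print(f"inline import: {inline:.3f}s  hoisted: {hoisted:.3f}s")
```

Per call the difference is fractions of a microsecond, but in a loop that runs every environment step it is free performance to reclaim, and hoisting also restores the PEP 8 layout.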
```python
def _sample_request_sizes(batch_size: int, sample_parallelism: int) -> list[int]:
    if sample_parallelism < 1:
        raise ValueError("sample_parallelism must be at least 1")
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")

    return split_batch(batch_size, sample_parallelism)
```
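The `split_batch` helper is not shown in the diff. A plausible implementation, assuming it performs a near-even split of the batch across parallel sample requests (this is a hypothetical sketch, not the PR's code):

```python
def split_batch(batch_size: int, parts: int) -> list[int]:
    """Hypothetical sketch: divide batch_size into `parts` near-equal
    request sizes, spreading any remainder over the first chunks so
    the sizes always sum back to batch_size."""
    base, rem = divmod(batch_size, parts)
    return [base + (1 if i < rem else 0) for i in range(parts)]
```

For example, `split_batch(10, 3)` yields `[4, 3, 3]`. Note that when `parts` exceeds `batch_size` this sketch emits zero-sized chunks, which a real implementation might want to filter out or reject.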
Codecov Report: ✅ All modified and coverable lines are covered by tests.
Force-pushed from 46529e8 to 7f8e3f1, then from 7f8e3f1 to fb06d43.
Summary
- Adds `examples/rl/torchrl_dqn` based on the upstream TorchRL tutorial loop.
- Shards the replay buffer using `patch_object`.

Validation
- `python3 -m py_compile examples/rl/torchrl_dqn/main.py examples/rl/torchrl_dqn/model.py examples/rl/torchrl_dqn/collector.py examples/rl/torchrl_dqn/replay_buffer.py`
- `python3 examples/rl/torchrl_dqn/main.py --help`
- `sdk/python/.venv/bin/ruff check examples/rl/torchrl_dqn`
- `sdk/python/.venv/bin/ruff format --check examples/rl/torchrl_dqn`
- `uv run main.py --local --env acrobot --iterations 1 --collections 1 --frames-per-collection 2 --batch-size 1 --warmup-frames 1 --replay simple --metrics-json /tmp/torchrl-main-simple-smoke.json`
- `uv run main.py --local --env acrobot --iterations 1 --collections 1 --frames-per-collection 2 --batch-size 1 --warmup-frames 1 --replay sharded --replay-shards 2 --sample-work 8 --sample-parallelism 2 --metrics-json /tmp/torchrl-main-sharded-smoke.json`
- `git diff --check`

Distributed Flame runtime smoke was not run because it requires an active Flame cluster.