From 288847d6c24085bcf73ae4e92a962b1701115f2e Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Thu, 4 Dec 2025 16:59:49 +0000 Subject: [PATCH 01/12] feat: Add Agent0 development tools and documentation Co-authored-by: wbaker7702 --- IMPLEMENTATION_SUMMARY.md | 172 ++++++++++ INDEX.md | 221 +++++++++++++ Makefile | 128 ++++++++ PLAN.md | 632 ++++++++++++++++++++++++++++++++++++ QUICK_REFERENCE.md | 242 ++++++++++++++ scripts/audit_code.sh | 128 ++++++++ scripts/debug_helper.sh | 156 +++++++++ scripts/explore_codebase.sh | 92 ++++++ scripts/run_tests.sh | 92 ++++++ scripts/validate_build.sh | 124 +++++++ 10 files changed, 1987 insertions(+) create mode 100644 IMPLEMENTATION_SUMMARY.md create mode 100644 INDEX.md create mode 100644 Makefile create mode 100644 PLAN.md create mode 100644 QUICK_REFERENCE.md create mode 100755 scripts/audit_code.sh create mode 100755 scripts/debug_helper.sh create mode 100755 scripts/explore_codebase.sh create mode 100755 scripts/run_tests.sh create mode 100755 scripts/validate_build.sh diff --git a/IMPLEMENTATION_SUMMARY.md b/IMPLEMENTATION_SUMMARY.md new file mode 100644 index 0000000..d58a56e --- /dev/null +++ b/IMPLEMENTATION_SUMMARY.md @@ -0,0 +1,172 @@ +# Agent0 Series - Implementation Summary + +## ๐Ÿ“ฆ What Was Created + +This implementation provides a comprehensive plan and tooling for exploring, building, debugging, validating, auditing, deploying, and integrating the Agent0 series codebase. + +### ๐Ÿ“„ Documentation + +1. **PLAN.md** - Comprehensive development and deployment plan covering: + - Explorer: Codebase structure analysis + - Search: Component discovery strategies + - Build: Environment setup and validation + - Debug: Debugging tools and strategies + - Validate: Testing and validation procedures + - Audit: Code quality and security audits + - Deploy: Deployment architectures and procedures + - Integrate: Integration points and APIs + +2. **QUICK_REFERENCE.md** - Quick reference guide for common tasks + +3. 
**IMPLEMENTATION_SUMMARY.md** - This file, summarizing what was created + +### ๐Ÿ› ๏ธ Scripts + +All scripts are located in `/workspace/scripts/`: + +1. **explore_codebase.sh** - Codebase exploration tool + - Explore overall structure + - Find training components + - Locate tool servers + - Discover evaluation components + - Analyze dependencies + +2. **validate_build.sh** - Build validation script + - Check Python and CUDA availability + - Verify critical packages + - Validate file structure + - Check configuration files + +3. **run_tests.sh** - Test runner + - Unit tests + - Integration tests + - Quick validation tests + +4. **debug_helper.sh** - Debugging utilities + - GPU status monitoring + - Ray cluster status + - Log file checking + - SandboxFusion testing + - vLLM server testing + - Memory profiling + - Configuration validation + +5. **audit_code.sh** - Code audit tool + - Security scanning (Bandit) + - Dependency vulnerability checking (Safety) + - Code quality (Black, Flake8, Pylint) + - Dependency analysis + +### ๐Ÿ”ง Makefile + +Convenient Makefile with common commands: +- `make explore` - Explore codebase +- `make build` - Validate build +- `make test` - Run tests +- `make debug-*` - Various debugging commands +- `make audit` - Run audits +- `make help` - Show all commands + +## ๐Ÿš€ Usage + +### Quick Start + +```bash +# View all available commands +make help + +# Explore the codebase +make explore + +# Validate your build environment +make build + +# Run quick tests +make test-quick + +# Check GPU status +make debug-gpu + +# Run security audit +make audit-security +``` + +### Detailed Usage + +See `QUICK_REFERENCE.md` for detailed usage examples and `PLAN.md` for comprehensive documentation. 
+ +## ๐Ÿ“Š Coverage + +### โœ… Completed + +- [x] Comprehensive planning document +- [x] Codebase exploration tools +- [x] Build validation scripts +- [x] Testing infrastructure +- [x] Debugging utilities +- [x] Code audit tools +- [x] Quick reference guide +- [x] Makefile for easy execution + +### ๐Ÿ”„ Next Steps + +1. **Customize Configuration**: Update scripts with your specific paths and settings +2. **Set Up CI/CD**: Integrate scripts into your CI/CD pipeline +3. **Add Monitoring**: Set up monitoring dashboards for production +4. **Documentation**: Add project-specific documentation +5. **Testing**: Expand test coverage based on your needs + +## ๐ŸŽฏ Key Features + +### Exploration +- Automated codebase structure analysis +- Component discovery +- Dependency mapping + +### Build & Validation +- Environment verification +- Dependency checking +- Configuration validation + +### Debugging +- GPU monitoring +- Ray cluster diagnostics +- Service connectivity testing +- Memory profiling + +### Quality Assurance +- Security scanning +- Code quality checks +- Dependency auditing + +### Deployment +- Deployment readiness checks +- Configuration validation +- Integration testing + +## ๐Ÿ“ Notes + +- All scripts are executable and ready to use +- Scripts include error handling and informative output +- Makefile provides convenient shortcuts +- Documentation is comprehensive and searchable + +## ๐Ÿ”— Related Files + +- `PLAN.md` - Full development and deployment plan +- `QUICK_REFERENCE.md` - Quick reference guide +- `scripts/` - All executable scripts +- `Makefile` - Convenient command shortcuts + +## ๐Ÿ†˜ Support + +For issues or questions: +1. Check `PLAN.md` for detailed documentation +2. Review `QUICK_REFERENCE.md` for common tasks +3. Run `make debug-config` to check your setup +4. 
Use `make help` to see all available commands + +--- + +*Created: 2025-01-XX* +*Version: 1.0* diff --git a/INDEX.md b/INDEX.md new file mode 100644 index 0000000..56fe566 --- /dev/null +++ b/INDEX.md @@ -0,0 +1,221 @@ +# Agent0 Series - Development Tools Index + +## ๐Ÿ“š Documentation + +### Main Documents +1. **[PLAN.md](./PLAN.md)** - Comprehensive development and deployment plan + - Complete guide covering all aspects: explore, search, build, debug, validate, audit, deploy, integrate + - Detailed procedures and best practices + - Architecture diagrams and configurations + +2. **[QUICK_REFERENCE.md](./QUICK_REFERENCE.md)** - Quick reference guide + - Common commands and workflows + - Troubleshooting tips + - Configuration examples + +3. **[IMPLEMENTATION_SUMMARY.md](./IMPLEMENTATION_SUMMARY.md)** - Implementation summary + - Overview of created tools and scripts + - Usage instructions + - Coverage and next steps + +4. **[README.md](./README.md)** - Project README + - Project overview and features + - Results and benchmarks + - Citation information + +## ๐Ÿ› ๏ธ Tools & Scripts + +### Scripts Directory: `/workspace/scripts/` + +| Script | Purpose | Usage | +|--------|---------|-------| +| `explore_codebase.sh` | Explore codebase structure | `./scripts/explore_codebase.sh [component]` | +| `validate_build.sh` | Validate build environment | `./scripts/validate_build.sh` | +| `run_tests.sh` | Run test suites | `./scripts/run_tests.sh [type]` | +| `debug_helper.sh` | Debugging utilities | `./scripts/debug_helper.sh [command]` | +| `audit_code.sh` | Code quality audits | `./scripts/audit_code.sh [type]` | + +### Makefile Commands + +Use `make help` to see all available commands, or: + +```bash +# Exploration +make explore # Explore codebase +make explore-training # Explore training components +make explore-tools # Explore tool servers + +# Build & Setup +make build # Validate build environment +make install # Install dependencies + +# Testing +make test # Run unit 
tests +make test-quick # Run quick tests +make validate # Full validation + +# Debugging +make debug-gpu # Check GPU status +make debug-ray # Check Ray cluster +make debug-config # Validate configuration + +# Auditing +make audit # Run all audits +make audit-security # Security audit +make audit-quality # Code quality audit +``` + +## ๐Ÿš€ Quick Start + +### 1. First Time Setup +```bash +# Install dependencies +make install + +# Validate build +make build + +# Explore codebase +make explore +``` + +### 2. Daily Development +```bash +# Quick validation +make test-quick + +# Check status +make debug-config + +# Run tests +make test +``` + +### 3. Before Deployment +```bash +# Full validation +make validate + +# Security audit +make audit-security + +# Deployment check +make deploy-check +``` + +## ๐Ÿ“‹ Workflow Guide + +### Exploration Phase +1. Run `make explore` to understand codebase structure +2. Use `make explore-training` to find training components +3. Check `PLAN.md` Section 1 (Explorer) for detailed analysis + +### Build Phase +1. Run `make build` to validate environment +2. Use `make install` to install dependencies +3. See `PLAN.md` Section 3 (Build) for setup procedures + +### Development Phase +1. Use `make test-quick` for rapid validation +2. Use `make debug-*` commands for troubleshooting +3. Refer to `PLAN.md` Section 4 (Debug) for debugging strategies + +### Validation Phase +1. Run `make test` for unit tests +2. Run `make validate` for full validation +3. See `PLAN.md` Section 5 (Validate) for testing procedures + +### Audit Phase +1. Run `make audit` for comprehensive audit +2. Review security findings +3. See `PLAN.md` Section 6 (Audit) for audit procedures + +### Deployment Phase +1. Run `make deploy-check` for readiness check +2. Review deployment configuration +3. See `PLAN.md` Section 7 (Deploy) for deployment guide + +### Integration Phase +1. Run `make integrate-check` for integration validation +2. Test API endpoints +3. 
See `PLAN.md` Section 8 (Integrate) for integration guide + +## ๐ŸŽฏ Use Cases + +### I want to understand the codebase +โ†’ Read `PLAN.md` Section 1 (Explorer) +โ†’ Run `make explore` + +### I want to set up my environment +โ†’ Read `PLAN.md` Section 3 (Build) +โ†’ Run `make build` and `make install` + +### I'm having build issues +โ†’ Run `make debug-config` +โ†’ Check `QUICK_REFERENCE.md` Troubleshooting section + +### I want to run tests +โ†’ Read `PLAN.md` Section 5 (Validate) +โ†’ Run `make test` or `make test-quick` + +### I want to check code quality +โ†’ Read `PLAN.md` Section 6 (Audit) +โ†’ Run `make audit` + +### I want to deploy +โ†’ Read `PLAN.md` Section 7 (Deploy) +โ†’ Run `make deploy-check` + +### I need quick help +โ†’ Check `QUICK_REFERENCE.md` +โ†’ Run `make help` + +## ๐Ÿ“Š File Structure + +``` +/workspace/ +โ”œโ”€โ”€ PLAN.md # Comprehensive plan +โ”œโ”€โ”€ QUICK_REFERENCE.md # Quick reference +โ”œโ”€โ”€ IMPLEMENTATION_SUMMARY.md # Implementation summary +โ”œโ”€โ”€ INDEX.md # This file +โ”œโ”€โ”€ Makefile # Convenient commands +โ”œโ”€โ”€ scripts/ # All utility scripts +โ”‚ โ”œโ”€โ”€ explore_codebase.sh +โ”‚ โ”œโ”€โ”€ validate_build.sh +โ”‚ โ”œโ”€โ”€ run_tests.sh +โ”‚ โ”œโ”€โ”€ debug_helper.sh +โ”‚ โ””โ”€โ”€ audit_code.sh +โ”œโ”€โ”€ Agent0/ # Agent0 codebase +โ”‚ โ”œโ”€โ”€ curriculum_train/ +โ”‚ โ”œโ”€โ”€ executor_train/ +โ”‚ โ””โ”€โ”€ requirements.txt +โ””โ”€โ”€ Agent0-VL/ # Agent0-VL codebase + โ””โ”€โ”€ README.md +``` + +## ๐Ÿ”— Related Resources + +- **Project Repository**: [Agent0 GitHub](https://github.com/aiming-lab/Agent0) +- **Agent0 Paper**: [arXiv:2511.16043](https://arxiv.org/abs/2511.16043) +- **Agent0-VL Paper**: [arXiv:2511.19900](https://arxiv.org/abs/2511.19900) +- **Documentation Website**: [Agent0 Website](https://aiming-lab.github.io/Agent0) + +## ๐Ÿ’ก Tips + +1. **Start with exploration**: Use `make explore` to understand the codebase +2. **Validate early**: Run `make build` before starting development +3. 
**Use quick tests**: `make test-quick` for rapid feedback +4. **Check configuration**: `make debug-config` when things don't work +5. **Read the plan**: `PLAN.md` has detailed procedures for everything + +## 🆘 Getting Help + +1. **Quick help**: Run `make help` or check `QUICK_REFERENCE.md` +2. **Detailed guide**: Read `PLAN.md` for comprehensive documentation +3. **Troubleshooting**: Use `make debug-*` commands and check logs +4. **Configuration**: Run `make debug-config` to validate setup + +--- + +*Last Updated: 2025-01-XX* +*For the latest information, see the individual documentation files.* diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..d2c9551 --- /dev/null +++ b/Makefile @@ -0,0 +1,128 @@ +.PHONY: help explore explore-all explore-training explore-tools explore-eval explore-deps build install test test-quick test-all validate debug-gpu debug-ray debug-config debug-memory debug-logs debug-sandbox debug-vllm audit audit-security audit-quality audit-deps deploy-check integrate-check + +help: + @echo "Agent0 Series - Development Commands" + @echo "====================================" + @echo "" + @echo "Exploration & Search:" + @echo " make explore - Explore codebase structure" + @echo " make explore-all - Full codebase exploration" + @echo " make explore-training - Explore training components" + @echo " make explore-tools - Explore tool servers" + @echo "" + @echo "Build & Setup:" + @echo " make build - Validate build environment" + @echo " make install - Install dependencies" + @echo "" + @echo "Testing & Validation:" + @echo " make test - Run unit tests" + @echo " make test-quick - Run quick tests" + @echo " make test-all - Run all tests" + @echo " make validate - Full validation check" + @echo "" + @echo "Debugging:" + @echo " make debug-gpu - Check GPU status" + @echo " make debug-ray - Check Ray cluster" + @echo " make debug-config - Validate configuration" + @echo " make debug-memory - Profile memory usage" + @echo "" + @echo "Code Quality:" + @echo " make audit - Run all audits" + @echo " make audit-security - Security audit" + @echo " make audit-quality - Code quality audit" + @echo "" + @echo "Deployment:" + @echo " make deploy-check - Check
deployment readiness" + @echo "" + @echo "Integration:" + @echo " make integrate-check - Check integration points" + @echo "" + +# Exploration: every target delegates to scripts/explore_codebase.sh with a component name +explore: + @bash scripts/explore_codebase.sh all + +explore-all: + @bash scripts/explore_codebase.sh all + +explore-training: + @bash scripts/explore_codebase.sh training + +explore-tools: + @bash scripts/explore_codebase.sh tools + +explore-eval: + @bash scripts/explore_codebase.sh evaluation + +explore-deps: + @bash scripts/explore_codebase.sh dependencies + +# Build: flash-attn install is best-effort (it only warns on failure) because it may require CUDA +build: + @bash scripts/validate_build.sh + +install: + @echo "Installing dependencies..." + @cd Agent0 && pip install -r requirements.txt + @cd Agent0/curriculum_train && pip install -r requirements.txt + @cd Agent0/executor_train/verl && pip install -e . + @echo "Installing Flash Attention..." + @pip install "flash-attn==2.8.3" --no-build-isolation || echo "Flash Attention installation may require CUDA" + +# Testing: targets delegate to scripts/run_tests.sh; validate chains build + test-quick +test: + @bash scripts/run_tests.sh unit + +test-quick: + @bash scripts/run_tests.sh quick + +test-all: + @bash scripts/run_tests.sh all + +validate: build test-quick + @echo "✅ Validation complete" + +# Debugging: targets delegate to scripts/debug_helper.sh sub-commands +debug-gpu: + @bash scripts/debug_helper.sh gpu-status + +debug-ray: + @bash scripts/debug_helper.sh ray-status + +debug-config: + @bash scripts/debug_helper.sh check-config + +debug-memory: + @bash scripts/debug_helper.sh memory-profile + +debug-logs: + @bash scripts/debug_helper.sh check-logs + +debug-sandbox: + @bash scripts/debug_helper.sh test-sandbox + +debug-vllm: + @bash scripts/debug_helper.sh test-vllm + +# Auditing: targets delegate to scripts/audit_code.sh audit types +audit: + @bash scripts/audit_code.sh all + +audit-security: + @bash scripts/audit_code.sh security + +audit-quality: + @bash scripts/audit_code.sh quality + +audit-deps: + @bash scripts/audit_code.sh dependencies + +# Deployment: gated on build validation, quick tests, and the security audit +deploy-check: build test-quick audit-security + @echo "✅ Deployment readiness check complete" + +# Integration: exits non-zero when a core import is missing so callers and CI see the failure +integrate-check: + @echo "Checking integration points..."
+ @python3 -c "import torch; import transformers; import ray; print('✅ Core integrations OK')" || { echo "❌ Integration check failed"; exit 1; } + @bash scripts/debug_helper.sh check-config diff --git a/PLAN.md b/PLAN.md new file mode 100644 index 0000000..31f1436 --- /dev/null +++ b/PLAN.md @@ -0,0 +1,632 @@ +# Agent0 Series: Comprehensive Development & Deployment Plan + +## 📋 Table of Contents +1. [Explorer](#explorer) +2. [Search](#search) +3. [Build](#build) +4. [Debug](#debug) +5. [Validate](#validate) +6. [Audit](#audit) +7. [Deploy](#deploy) +8. [Integrate](#integrate) + +--- + +## 🔍 Explorer + +### 1.1 Codebase Structure Analysis + +#### Agent0 (Language Agent) +- **Location**: `/workspace/Agent0/` +- **Components**: + - `curriculum_train/`: Curriculum agent training pipeline + - `question_generate/`: Task generation module + - `question_evaluate/`: Task evaluation and filtering + - `scripts/`: Training scripts + - `verl/`: VeRL framework integration + - `executor_train/`: Executor agent training pipeline + - `verl_tool/`: Tool-integrated RL framework + - `eval_service/`: Evaluation API service + - `examples/`: Training examples and configurations + +#### Agent0-VL (Vision-Language Agent) +- **Location**: `/workspace/Agent0-VL/` +- **Status**: Code release coming soon (per README) +- **Components**: Currently documentation only + +### 1.2 Key Dependencies +- **Core ML**: PyTorch 2.7-2.8, Transformers 4.52-4.57 +- **RL Framework**: VeRL (custom), Ray 2.46-2.51 +- **Inference**: vLLM 0.9-0.11, SGLang +- **Tools**: Flash Attention 2.7-2.8, SandboxFusion +- **Monitoring**: WandB, TensorBoard + +### 1.3 Architecture Patterns +- **Co-evolution**: Curriculum Agent ↔ Executor Agent +- **Tool Integration**: Code interpreter, search, vision APIs +- **Multi-turn RL**: ADPO, GRPO, DAPO algorithms +- **Distributed Training**: FSDP, Megatron-LM support + +--- + +## 🔎 Search + +### 2.1 Component Discovery Strategy + +#### Search Patterns +```bash +# Find all training
scripts +find . -name "*train*.sh" -type f + +# Find configuration files +find . -name "*.yaml" -type f | grep -E "(config|train)" + +# Find entry points +grep -r "if __name__" --include="*.py" + +# Find API endpoints +grep -r "@app\." --include="*.py" +grep -r "FastAPI\|Flask" --include="*.py" +``` + +#### Key Components to Locate +1. **Training Entry Points**: + - `curriculum_train/scripts/curriculum_train.sh` + - `executor_train/examples/train/math_tir/train_qwen3_4b_adpo.sh` + - `executor_train/verl_tool/trainer/main.py` + +2. **Evaluation Services**: + - `executor_train/eval_service/` + - `curriculum_train/question_evaluate/evaluate.py` + +3. **Tool Servers**: + - `executor_train/verl_tool/servers/` + - SandboxFusion integration points + +4. **Model Checkpoints**: + - Checkpoint managers: `verl/utils/checkpoint/` + - Model merging: `curriculum_train/scripts/model_merger.py` + +### 2.2 Dependency Mapping +- **External Services**: SandboxFusion, vLLM servers, Ray cluster +- **Model Sources**: HuggingFace (Qwen models) +- **Storage**: WandB, local filesystem, S3 (via boto3) + +--- + +## ๐Ÿ—๏ธ Build + +### 3.1 Environment Setup + +#### Prerequisites +```bash +# System Requirements +- CUDA 12.x compatible GPUs +- Python 3.8+ +- CUDA toolkit 12.x +- NCCL for distributed training +``` + +#### Installation Steps + +**Step 1: Base Environment** +```bash +cd /workspace/Agent0/Agent0 + +# Install base requirements +pip install -r requirements.txt + +# Install VeRL framework +pip install -e verl + +# Install Flash Attention (requires CUDA) +pip install "flash-attn==2.8.3" --no-build-isolation +``` + +**Step 2: Curriculum Training Setup** +```bash +cd curriculum_train/ +pip install -r requirements.txt +``` + +**Step 3: Executor Training Setup** +```bash +cd executor_train/ +pip install -e verl +pip install -e verl_tool +``` + +### 3.2 External Service Setup + +#### SandboxFusion Service +```bash +# Clone and setup SandboxFusion +git clone 
https://github.com/bytedance/SandboxFusion.git +cd SandboxFusion +poetry install +make run-online + +# Configure in Agent0 +# Edit: curriculum_train/vllm_service_init/start_vllm_server_tool.py +# Lines 36-41: Add sandbox API URLs +``` + +#### vLLM Server Initialization +```bash +cd curriculum_train/vllm_service_init/ +bash start.sh +``` + +### 3.3 Build Validation +```bash +# Verify installations +python -c "import torch; print(torch.__version__)" +python -c "import flash_attn; print('Flash Attention OK')" +python -c "import ray; print(ray.__version__)" +python -c "import vllm; print(vllm.__version__)" + +# Test VeRL installation +cd executor_train/verl +python -m pytest tests/ -v -k "test_basic" --tb=short +``` + +--- + +## ๐Ÿ› Debug + +### 4.1 Debugging Tools & Strategies + +#### Logging Infrastructure +- **WandB**: Training metrics and visualization +- **TensorBoard**: Local training logs +- **Python Logging**: Structured logging via `verl/utils/logger/` + +#### Debug Configuration +```python +# Enable debug mode in training scripts +export DEBUG=1 +export LOG_LEVEL=DEBUG + +# Ray debugging +export RAY_BACKEND_LOG_LEVEL=debug +``` + +#### Common Debug Scenarios + +**1. CUDA Memory Issues** +```bash +# Monitor GPU memory +watch -n 1 nvidia-smi + +# Reduce batch size in config files +# Look for: batch_size, micro_batch_size, gradient_accumulation_steps +``` + +**2. Distributed Training Issues** +```bash +# Test Ray cluster +ray status + +# Check worker connectivity +python -c "import ray; ray.init(); print(ray.nodes())" +``` + +**3. Tool Execution Failures** +```bash +# Test SandboxFusion connection +curl -X POST http://SANDBOX_IP:PORT/run_code \ + -H "Content-Type: application/json" \ + -d '{"code": "print(1+1)"}' + +# Check tool server logs +tail -f verl_tool/servers/logs/*.log +``` + +**4. 
Model Loading Issues** +```bash +# Verify model access +python -c "from transformers import AutoModel; AutoModel.from_pretrained('Qwen/Qwen3-4B-Base')" + +# Check checkpoint integrity +python curriculum_train/scripts/model_merger.py --check-only +``` + +### 4.2 Debugging Scripts +- **Profile Training**: Use `py-spy` for performance profiling +- **Trace Execution**: Enable detailed logging in `verl/utils/logger/` +- **Memory Profiling**: Use `torch.profiler` for memory analysis + +--- + +## โœ… Validate + +### 5.1 Testing Strategy + +#### Unit Tests +```bash +# Run VeRL unit tests +cd executor_train/verl +pytest tests/ -v + +# Run tool server tests +cd verl_tool/servers/tests/ +pytest test_*.py -v + +# Run evaluation service tests +cd executor_train/eval_service/test/ +pytest test_api.py -v +``` + +#### Integration Tests +```bash +# Test curriculum training pipeline +cd curriculum_train/ +bash scripts/curriculum_train.sh Qwen/Qwen3-4B-Base Qwen/Qwen3-4B-Base test_run --dry-run + +# Test executor training (small scale) +cd executor_train/ +bash examples/train/math_tir/train_qwen3_4b_adpo.sh --test-mode +``` + +#### End-to-End Validation +```bash +# Full pipeline test (requires GPU) +# 1. Train curriculum agent (1 iteration) +# 2. Generate questions +# 3. Evaluate questions +# 4. Train executor agent (1 step) +# 5. 
Validate checkpoint +``` + +### 5.2 Benchmark Validation + +#### Mathematical Reasoning Benchmarks +- **MATH**: Verify accuracy > 78% +- **GSM8K**: Verify accuracy > 89% +- **AMC**: Verify accuracy > 52% + +#### General Reasoning Benchmarks +- **MMLU-Pro**: Verify accuracy > 51% +- **SuperGPQA**: Verify accuracy > 28% + +### 5.3 CI/CD Validation +- **Pre-commit**: Code formatting and linting +- **GitHub Actions**: Automated testing (see `.github/workflows/`) +- **Type Checking**: mypy validation +- **Security Scanning**: Dependabot, secret scanning + +--- + +## 🔒 Audit + +### 6.1 Code Quality Audit + +#### Static Analysis +```bash +# Install audit tools +pip install pylint black flake8 mypy bandit safety + +# Code formatting check +black --check --diff . + +# Linting +flake8 . --max-line-length=120 --exclude=venv,__pycache__ + +# Type checking +mypy . --ignore-missing-imports + +# Security audit +bandit -r . -ll +safety check +``` + +#### Code Review Checklist +- [ ] Security: No hardcoded credentials +- [ ] Performance: Efficient data loading and batching +- [ ] Error Handling: Proper exception handling +- [ ] Documentation: Docstrings for public APIs +- [ ] Testing: Unit tests for critical paths + +### 6.2 Dependency Audit + +#### Vulnerability Scanning +```bash +# Check for known vulnerabilities +pip install pip-audit +pip-audit + +# List outdated dependencies (candidates for update) +pip list --outdated +``` + +#### License Compliance +- Verify all dependencies are compatible with Apache 2.0 +- Check for GPL dependencies that may require disclosure + +### 6.3 Performance Audit + +#### Training Efficiency +- Monitor GPU utilization (target: >80%) +- Check for data loading bottlenecks +- Verify distributed training scaling + +#### Memory Audit +- Profile memory usage during training +- Check for memory leaks in long-running processes +- Optimize batch sizes for available hardware + +--- + +## 🚀 Deploy + +### 7.1 Deployment Architecture + +#### Training Deployment +```
+┌─────────────────┐ +│ Ray Cluster │ +│ (Controller) │ +└────────┬────────┘ + │ + ┌────┴────┐ + │ │ +┌───▼───┐ ┌──▼────┐ +│Actor │ │Critic │ +│Worker │ │Worker │ +└───────┘ └───────┘ +``` + +#### Service Deployment +- **vLLM Servers**: Model inference endpoints +- **SandboxFusion**: Code execution sandboxes +- **Evaluation API**: `executor_train/eval_service/` + +### 7.2 Deployment Configurations + +#### Development Environment +```bash +# Single GPU, local Ray +export RAY_ADDRESS="" +ray start --head + +# Local vLLM server +cd curriculum_train/vllm_service_init/ +bash start.sh +``` + +#### Production Environment +```bash +# Multi-node Ray cluster +ray start --head --port=6379 +# On worker nodes: +ray start --address=HEAD_NODE_IP:6379 + +# Distributed vLLM (multiple GPUs) +# Configure in vllm_service_init/start_vllm_server_tool.py +``` + +### 7.3 Containerization (Future) + +#### Docker Setup +```dockerfile +# Base image with CUDA +FROM nvidia/cuda:12.1.0-cudnn8-runtime-ubuntu22.04 + +# Install Python and dependencies +RUN apt-get update && apt-get install -y python3.10 python3-pip +COPY requirements.txt . +RUN pip install -r requirements.txt + +# Install Flash Attention +RUN pip install flash-attn==2.8.3 --no-build-isolation + +# Copy application +COPY .
/app +WORKDIR /app +``` + +#### Kubernetes Deployment +- **Training Jobs**: Kubernetes Jobs for training runs +- **Services**: Deployments for inference and evaluation APIs +- **Storage**: Persistent volumes for checkpoints and data + +### 7.4 Monitoring & Observability + +#### Metrics Collection +- **WandB**: Training metrics, hyperparameters +- **Prometheus**: System metrics (via prometheus-fastapi-instrumentator) +- **Ray Dashboard**: Distributed training monitoring + +#### Logging +- Centralized logging for all services +- Structured JSON logs for parsing +- Log aggregation (ELK stack or similar) + +--- + +## 🔗 Integrate + +### 8.1 Integration Points + +#### 1. Model Integration +```python +# Load trained Agent0 model +from transformers import AutoModelForCausalLM, AutoTokenizer + +model_path = "path/to/agent0/checkpoint" +model = AutoModelForCausalLM.from_pretrained(model_path) +tokenizer = AutoTokenizer.from_pretrained(model_path) +``` + +#### 2. Tool Integration +```python +# Use tool-integrated reasoning +from verl_tool.servers import SandboxFusionTool + +tool = SandboxFusionTool(config=config) +result = tool.execute(code="print(1+1)") +``` + +#### 3.
Evaluation Integration +```python +# Use evaluation service +from eval_service import EvaluationAPI + +api = EvaluationAPI(endpoint="http://eval-service:8000") +score = api.evaluate(model_output, ground_truth) +``` + +### 8.2 External System Integration + +#### HuggingFace Hub +```python +# Upload checkpoints +from huggingface_hub import HfApi + +api = HfApi() +api.upload_folder( + folder_path="checkpoints/agent0", + repo_id="username/agent0-model", + repo_type="model" +) +``` + +#### WandB Integration +```python +# Logging to WandB +import wandb + +wandb.init(project="agent0-training") +wandb.log({"metric": value}) +``` + +#### Ray Integration +```python +# Distributed training with Ray +import ray + +@ray.remote +def train_worker(config): + # Training logic + pass + +ray.init() +futures = [train_worker.remote(config) for _ in range(num_workers)] +results = ray.get(futures) +``` + +### 8.3 API Integration + +#### Evaluation API +```bash +# Start evaluation service +cd executor_train/eval_service/ +bash scripts/start_api_service.sh + +# API endpoints +POST /evaluate - Evaluate model outputs +GET /health - Health check +GET /metrics - Prometheus metrics +``` + +#### Model Serving API +```python +# vLLM offline inference via the Python API (for the OpenAI-compatible HTTP server, use `vllm serve`) +from vllm import LLM, SamplingParams + +llm = LLM(model="path/to/model") +sampling_params = SamplingParams(temperature=0.7, top_p=0.95) +outputs = llm.generate(prompts, sampling_params) +``` + +### 8.4 CI/CD Integration + +#### GitHub Actions Workflows +- **Pre-commit**: Code quality checks +- **Unit Tests**: Automated test execution +- **Integration Tests**: End-to-end validation +- **Deployment**: Automated deployment on release + +#### Workflow Triggers +- Push to main: Run full test suite +- Pull requests: Run pre-commit and unit tests +- Tags: Trigger deployment pipeline + +--- + +## 📊 Implementation Checklist + +### Phase 1: Exploration & Setup +- [ ] Complete codebase exploration +- [ ] Document architecture and data flows +- [ ] Set
up development environment +- [ ] Verify all dependencies + +### Phase 2: Build & Validation +- [ ] Build all components successfully +- [ ] Run unit test suite +- [ ] Validate integration tests +- [ ] Benchmark performance baseline + +### Phase 3: Debug & Audit +- [ ] Set up debugging infrastructure +- [ ] Run code quality audits +- [ ] Security vulnerability scan +- [ ] Performance profiling + +### Phase 4: Deploy & Integrate +- [ ] Set up production environment +- [ ] Deploy services +- [ ] Configure monitoring +- [ ] Test integrations +- [ ] Document deployment procedures + +--- + +## ๐Ÿ“ Notes + +### Key Configuration Files +- Training: `examples/train/math_tir/train_qwen3_4b_adpo.sh` +- Curriculum: `curriculum_train/scripts/curriculum_train.sh` +- Evaluation: `curriculum_train/question_evaluate/evaluate.sh` +- Tools: `curriculum_train/vllm_service_init/start_vllm_server_tool.py` + +### Critical Environment Variables +```bash +export STORAGE_PATH="/path/to/storage" +export HUGGINGFACENAME="Qwen/Qwen3-4B-Base" +export WANDB_API_KEY="your_key" +export VLLM_DISABLE_COMPILE_CACHE=1 +``` + +### Storage Structure +``` +$STORAGE_PATH/ +โ”œโ”€โ”€ evaluation/ +โ”œโ”€โ”€ models/ +โ”œโ”€โ”€ generated_question/ +โ””โ”€โ”€ temp_results/ +``` + +--- + +## ๐Ÿ”„ Maintenance + +### Regular Tasks +- Weekly dependency updates +- Monthly security audits +- Quarterly performance reviews +- Continuous monitoring of training jobs + +### Update Procedures +1. Test in development environment +2. Run full test suite +3. Deploy to staging +4. Validate in staging +5. Deploy to production +6. Monitor for issues + +--- + +*Last Updated: 2025-01-XX* +*Version: 1.0* diff --git a/QUICK_REFERENCE.md b/QUICK_REFERENCE.md new file mode 100644 index 0000000..977af8f --- /dev/null +++ b/QUICK_REFERENCE.md @@ -0,0 +1,242 @@ +# Agent0 Series - Quick Reference Guide + +## ๐Ÿš€ Quick Start + +### 1. 
Setup Environment +```bash +# Install dependencies +make install + +# Validate build +make build +``` + +### 2. Explore Codebase +```bash +# Full exploration +make explore + +# Specific components +make explore-training +make explore-tools +``` + +### 3. Run Tests +```bash +# Quick tests +make test-quick + +# Full test suite +make test-all +``` + +### 4. Debug Issues +```bash +# Check GPU +make debug-gpu + +# Check Ray cluster +make debug-ray + +# Check configuration +make debug-config +``` + +## ๐Ÿ“‹ Common Commands + +### Exploration +```bash +./scripts/explore_codebase.sh [component] +# Components: all, training, tools, evaluation, dependencies +``` + +### Build Validation +```bash +./scripts/validate_build.sh +``` + +### Testing +```bash +./scripts/run_tests.sh [type] +# Types: unit, integration, quick, all +``` + +### Debugging +```bash +./scripts/debug_helper.sh [command] +# Commands: gpu-status, ray-status, check-logs, test-sandbox, +# test-vllm, memory-profile, check-config +``` + +### Auditing +```bash +./scripts/audit_code.sh [type] +# Types: all, security, quality, dependencies +``` + +## ๐Ÿ”ง Configuration + +### Required Environment Variables +```bash +export STORAGE_PATH="/path/to/storage" +export HUGGINGFACENAME="Qwen/Qwen3-4B-Base" +export WANDB_API_KEY="your_key" +export VLLM_DISABLE_COMPILE_CACHE=1 +``` + +### Storage Structure +``` +$STORAGE_PATH/ +โ”œโ”€โ”€ evaluation/ +โ”œโ”€โ”€ models/ +โ”œโ”€โ”€ generated_question/ +โ””โ”€โ”€ temp_results/ +``` + +## ๐Ÿ—๏ธ Training Workflow + +### 1. Train Curriculum Agent +```bash +cd Agent0/curriculum_train/ +bash scripts/curriculum_train.sh \ + Qwen/Qwen3-4B-Base \ + Qwen/Qwen3-4B-Base \ + qwen3_4b_curriculum_v1 +``` + +### 2. Generate Questions +```bash +curriculum_agent_path=${STORAGE_PATH}/models/qwen3_4b_curriculum_v1/global_step_5/actor/huggingface +experiment_name=qwen3_4b_executor_v1 + +bash question_generate/question_generate.bash \ + $curriculum_agent_path 1000 $experiment_name +``` + +### 3. 
Evaluate Questions +```bash +executor_agent_path=Qwen/Qwen3-4B-Base +bash question_evaluate/evaluate.sh \ + $executor_agent_path $experiment_name +``` + +### 4. Train Executor Agent +```bash +cd ../executor_train +bash examples/train/math_tir/train_qwen3_4b_adpo.sh +``` + +## 🐛 Troubleshooting + +### GPU Memory Issues +```bash +# Monitor GPU +watch -n 1 nvidia-smi + +# Reduce batch size in config files +# Look for: batch_size, micro_batch_size +``` + +### Ray Connection Issues +```bash +# Start Ray cluster +ray start --head + +# Check status +ray status + +# Debug +make debug-ray +``` + +### SandboxFusion Issues +```bash +# Test connection +make debug-sandbox + +# Check configuration +grep SANDBOX_API_URLS Agent0/curriculum_train/vllm_service_init/start_vllm_server_tool.py +``` + +### Model Loading Issues +```bash +# Verify model access +python3 -c "from transformers import AutoModel; \ + AutoModel.from_pretrained('Qwen/Qwen3-4B-Base')" + +# Check checkpoint +python3 Agent0/curriculum_train/scripts/model_merger.py --check-only +``` + +## 📊 Monitoring + +### Training Metrics +- **WandB**: Automatic logging during training +- **TensorBoard**: Local logs in `logs/` directory +- **Ray Dashboard**: `http://localhost:8265` (if Ray is running) + +### System Monitoring +```bash +# GPU usage +make debug-gpu + +# Memory usage +make debug-memory + +# Check logs +make debug-logs +``` + +## 🔗 Integration Points + +### Load Trained Model +```python +from transformers import AutoModelForCausalLM, AutoTokenizer + +model = AutoModelForCausalLM.from_pretrained("path/to/checkpoint") +tokenizer = AutoTokenizer.from_pretrained("path/to/checkpoint") +``` + +### Use Evaluation API +```python +from eval_service import EvaluationAPI + +api = EvaluationAPI(endpoint="http://eval-service:8000") +score = api.evaluate(model_output, ground_truth) +``` + +### Use Tool Integration +```python +from verl_tool.servers import SandboxFusionTool + +tool = SandboxFusionTool(config=config) 
+result = tool.execute(code="print(1+1)") +``` + +## ๐Ÿ“š Key Files + +### Training Scripts +- `Agent0/curriculum_train/scripts/curriculum_train.sh` +- `Agent0/executor_train/examples/train/math_tir/train_qwen3_4b_adpo.sh` + +### Configuration +- `Agent0/curriculum_train/vllm_service_init/start_vllm_server_tool.py` +- `Agent0/curriculum_train/examples/config.yaml` + +### Evaluation +- `Agent0/curriculum_train/question_evaluate/evaluate.sh` +- `Agent0/executor_train/eval_service/scripts/start_api_service.sh` + +## ๐Ÿ†˜ Getting Help + +1. **Check Documentation**: See `PLAN.md` for comprehensive guide +2. **Run Diagnostics**: `make debug-config` +3. **Check Logs**: `make debug-logs` +4. **Validate Setup**: `make validate` + +## ๐Ÿ“ Notes + +- Always check GPU availability before training +- Ensure SandboxFusion is running before curriculum training +- Set all required environment variables before starting +- Monitor disk space for checkpoints and generated data diff --git a/scripts/audit_code.sh b/scripts/audit_code.sh new file mode 100755 index 0000000..b980ef4 --- /dev/null +++ b/scripts/audit_code.sh @@ -0,0 +1,128 @@ +#!/bin/bash +# Agent0 Code Audit Script +# Usage: ./scripts/audit_code.sh [audit_type] + +set -e + +AUDIT_TYPE=${1:-"all"} +BASE_DIR="/workspace/Agent0" + +echo "๐Ÿ”’ Agent0 Code Audit" +echo "====================" +echo "" + +# Install audit tools if needed +install_audit_tools() { + echo "๐Ÿ“ฆ Installing audit tools..." + pip install --quiet pylint black flake8 bandit safety 2>/dev/null || { + echo "โš ๏ธ Some tools may already be installed" + } +} + +case $AUDIT_TYPE in + "all") + install_audit_tools + $0 security + $0 quality + $0 dependencies + ;; + + "security") + echo "๐Ÿ” Security Audit" + echo "----------------" + + if command -v bandit &> /dev/null; then + echo "Running Bandit security scan..." + bandit -r "$BASE_DIR" -ll -f json -o /tmp/bandit_report.json 2>/dev/null || { + echo "โš ๏ธ Security issues found. 
Check /tmp/bandit_report.json" + } + echo "โœ… Security scan complete" + else + echo "โš ๏ธ Bandit not installed. Install with: pip install bandit" + fi + echo "" + + if command -v safety &> /dev/null; then + echo "Checking for known vulnerabilities..." + safety check --json 2>/dev/null || { + echo "โš ๏ธ Vulnerable packages found" + } + echo "โœ… Dependency vulnerability check complete" + else + echo "โš ๏ธ Safety not installed. Install with: pip install safety" + fi + echo "" + ;; + + "quality") + echo "๐Ÿ“Š Code Quality Audit" + echo "--------------------" + + if command -v black &> /dev/null; then + echo "Checking code formatting with Black..." + black --check --diff "$BASE_DIR" 2>/dev/null || { + echo "โš ๏ธ Code formatting issues found" + } + echo "โœ… Formatting check complete" + else + echo "โš ๏ธ Black not installed" + fi + echo "" + + if command -v flake8 &> /dev/null; then + echo "Running Flake8 linting..." + flake8 "$BASE_DIR" --max-line-length=120 --exclude=venv,__pycache__,*.egg-info --count --statistics 2>/dev/null || { + echo "โš ๏ธ Linting issues found" + } + echo "โœ… Linting complete" + else + echo "โš ๏ธ Flake8 not installed" + fi + echo "" + + if command -v pylint &> /dev/null; then + echo "Running Pylint analysis..." + pylint "$BASE_DIR" --disable=all --enable=E,W --max-line-length=120 2>/dev/null | head -50 || { + echo "โš ๏ธ Code quality issues found" + } + echo "โœ… Pylint analysis complete" + else + echo "โš ๏ธ Pylint not installed" + fi + echo "" + ;; + + "dependencies") + echo "๐Ÿ“ฆ Dependency Audit" + echo "-------------------" + + echo "Checking for outdated packages..." + pip list --outdated 2>/dev/null | head -20 || { + echo "โš ๏ธ Could not check outdated packages" + } + echo "" + + echo "Checking for duplicate dependencies..." 
+ # Check for version conflicts in requirements files + if [ -f "$BASE_DIR/requirements.txt" ]; then + echo "Main requirements:" + grep -E "^[a-zA-Z]" "$BASE_DIR/requirements.txt" | cut -d'=' -f1 | sort | uniq -d || { + echo " โœ… No duplicates found" + } + fi + echo "" + + echo "Checking license compatibility..." + echo "โš ๏ธ Manual license check recommended" + echo " Verify all dependencies are compatible with Apache 2.0" + echo "" + ;; + + *) + echo "Unknown audit type: $AUDIT_TYPE" + echo "Available types: all, security, quality, dependencies" + exit 1 + ;; +esac + +echo "โœ… Audit complete!" diff --git a/scripts/debug_helper.sh b/scripts/debug_helper.sh new file mode 100755 index 0000000..4243e14 --- /dev/null +++ b/scripts/debug_helper.sh @@ -0,0 +1,156 @@ +#!/bin/bash +# Agent0 Debug Helper Script +# Usage: ./scripts/debug_helper.sh [command] [args...] + +set -e + +COMMAND=${1:-"help"} +BASE_DIR="/workspace/Agent0" + +case $COMMAND in + "help") + echo "๐Ÿ› Agent0 Debug Helper" + echo "======================" + echo "" + echo "Usage: ./scripts/debug_helper.sh [command] [args...]" + echo "" + echo "Commands:" + echo " gpu-status - Show GPU status and memory usage" + echo " ray-status - Check Ray cluster status" + echo " check-logs - Show recent log files" + echo " test-sandbox - Test SandboxFusion connection" + echo " test-vllm - Test vLLM server connection" + echo " memory-profile - Profile memory usage" + echo " check-config - Validate configuration files" + echo "" + ;; + + "gpu-status") + echo "๐ŸŽฎ GPU Status" + echo "------------" + nvidia-smi --query-gpu=index,name,memory.used,memory.total,utilization.gpu --format=csv,noheader,nounits || { + echo "โš ๏ธ nvidia-smi not available (may not have GPU)" + } + ;; + + "ray-status") + echo "โ˜€๏ธ Ray Cluster Status" + echo "---------------------" + python3 -c " +import ray +try: + ray.init(address='auto', ignore_reinit_error=True) + print('โœ… Ray connected') + print(f'Nodes: {len(ray.nodes())}') + 
print(f'Resources: {ray.available_resources()}') +except Exception as e: + print(f'โš ๏ธ Ray not initialized: {e}') + print('Start Ray with: ray start --head') +" || echo "โš ๏ธ Ray check failed" + ;; + + "check-logs") + echo "๐Ÿ“‹ Recent Logs" + echo "-------------" + LOG_DIRS=( + "$BASE_DIR/curriculum_train" + "$BASE_DIR/executor_train" + ) + + for dir in "${LOG_DIRS[@]}"; do + if [ -d "$dir" ]; then + echo "Logs in $dir:" + find "$dir" -name "*.log" -type f -mtime -1 2>/dev/null | head -5 || echo " No recent logs" + fi + done + ;; + + "test-sandbox") + SANDBOX_URL=${2:-"http://localhost:8000/run_code"} + echo "๐Ÿงช Testing SandboxFusion" + echo "-----------------------" + echo "URL: $SANDBOX_URL" + + curl -X POST "$SANDBOX_URL" \ + -H "Content-Type: application/json" \ + -d '{"code": "print(1+1)", "language": "python"}' \ + -w "\nHTTP Status: %{http_code}\n" || { + echo "โŒ Sandbox connection failed" + echo "Make sure SandboxFusion is running" + } + ;; + + "test-vllm") + VLLM_URL=${2:-"http://localhost:8000/v1/completions"} + echo "๐Ÿš€ Testing vLLM Server" + echo "---------------------" + echo "URL: $VLLM_URL" + + curl -X POST "$VLLM_URL" \ + -H "Content-Type: application/json" \ + -d '{"model": "test", "prompt": "Hello", "max_tokens": 10}' \ + -w "\nHTTP Status: %{http_code}\n" || { + echo "โŒ vLLM connection failed" + echo "Make sure vLLM server is running" + } + ;; + + "memory-profile") + echo "๐Ÿ’พ Memory Profiling" + echo "-------------------" + python3 -c " +import torch +import psutil +import os + +process = psutil.Process(os.getpid()) +mem_info = process.memory_info() +print(f'Process Memory: {mem_info.rss / 1024 / 1024:.2f} MB') + +if torch.cuda.is_available(): + for i in range(torch.cuda.device_count()): + print(f'GPU {i} Memory:') + print(f' Allocated: {torch.cuda.memory_allocated(i) / 1024**3:.2f} GB') + print(f' Reserved: {torch.cuda.memory_reserved(i) / 1024**3:.2f} GB') +else: + print('โš ๏ธ CUDA not available') +" + ;; + + "check-config") + 
echo "โš™๏ธ Configuration Check" + echo "----------------------" + + # Check environment variables + echo "Environment Variables:" + for var in STORAGE_PATH HUGGINGFACENAME WANDB_API_KEY; do + if [ -n "${!var}" ]; then + echo " โœ… $var is set" + else + echo " โš ๏ธ $var is not set" + fi + done + echo "" + + # Check config files + echo "Configuration Files:" + CONFIG_FILES=( + "$BASE_DIR/curriculum_train/vllm_service_init/start_vllm_server_tool.py" + "$BASE_DIR/curriculum_train/scripts/curriculum_train.sh" + ) + + for file in "${CONFIG_FILES[@]}"; do + if [ -f "$file" ]; then + echo " โœ… $(basename $file) exists" + else + echo " โŒ $(basename $file) missing" + fi + done + ;; + + *) + echo "Unknown command: $COMMAND" + $0 help + exit 1 + ;; +esac diff --git a/scripts/explore_codebase.sh b/scripts/explore_codebase.sh new file mode 100755 index 0000000..d568758 --- /dev/null +++ b/scripts/explore_codebase.sh @@ -0,0 +1,92 @@ +#!/bin/bash +# Agent0 Codebase Explorer Script +# Usage: ./scripts/explore_codebase.sh [component] + +set -e + +COMPONENT=${1:-"all"} +BASE_DIR="/workspace/Agent0" + +echo "๐Ÿ” Agent0 Codebase Explorer" +echo "============================" +echo "" + +case $COMPONENT in + "all") + echo "๐Ÿ“Š Overall Statistics" + echo "---------------------" + echo "Python files: $(find $BASE_DIR -name "*.py" | wc -l)" + echo "Shell scripts: $(find $BASE_DIR -name "*.sh" | wc -l)" + echo "Config files: $(find $BASE_DIR -name "*.yaml" | wc -l)" + echo "" + + echo "๐Ÿ—๏ธ Key Components" + echo "-----------------" + echo "Curriculum Training:" + find $BASE_DIR/curriculum_train -maxdepth 2 -type d | head -10 + echo "" + echo "Executor Training:" + find $BASE_DIR/executor_train -maxdepth 2 -type d | head -10 + echo "" + + echo "๐Ÿ“ Entry Points" + echo "---------------" + grep -r "if __name__" $BASE_DIR --include="*.py" | head -10 + echo "" + ;; + + "training") + echo "๐ŸŽ“ Training Scripts" + echo "-------------------" + find $BASE_DIR -name "*train*.sh" 
-type f + echo "" + + echo "๐Ÿ“‹ Training Configs" + echo "------------------" + find $BASE_DIR -name "*.yaml" -path "*/train*" -o -name "*config*.yaml" | head -20 + echo "" + ;; + + "tools") + echo "๐Ÿ”ง Tool Servers" + echo "--------------" + find $BASE_DIR/executor_train/verl_tool/servers -name "*.py" -type f | grep -v test | grep -v __pycache__ + echo "" + + echo "๐Ÿงช Tool Tests" + echo "-------------" + find $BASE_DIR/executor_train/verl_tool/servers/tests -name "test_*.py" -type f + echo "" + ;; + + "evaluation") + echo "๐Ÿ“Š Evaluation Components" + echo "-----------------------" + find $BASE_DIR -path "*/eval*" -name "*.py" -type f | head -20 + echo "" + + echo "๐Ÿ“ˆ Evaluation Scripts" + echo "--------------------" + find $BASE_DIR -name "*evaluate*.sh" -o -name "*evaluate*.py" | head -10 + echo "" + ;; + + "dependencies") + echo "๐Ÿ“ฆ Dependencies" + echo "--------------" + echo "Main requirements:" + cat $BASE_DIR/requirements.txt | head -20 + echo "" + echo "Curriculum requirements:" + cat $BASE_DIR/curriculum_train/requirements.txt | head -20 + echo "" + ;; + + *) + echo "Unknown component: $COMPONENT" + echo "Available components: all, training, tools, evaluation, dependencies" + exit 1 + ;; +esac + +echo "โœ… Exploration complete!" diff --git a/scripts/run_tests.sh b/scripts/run_tests.sh new file mode 100755 index 0000000..9bb367e --- /dev/null +++ b/scripts/run_tests.sh @@ -0,0 +1,92 @@ +#!/bin/bash +# Agent0 Test Runner Script +# Usage: ./scripts/run_tests.sh [test_type] + +set -e + +TEST_TYPE=${1:-"unit"} +BASE_DIR="/workspace/Agent0" + +echo "๐Ÿงช Agent0 Test Runner" +echo "=====================" +echo "" + +case $TEST_TYPE in + "unit") + echo "๐Ÿ“ Running Unit Tests" + echo "---------------------" + + # VeRL unit tests + if [ -d "$BASE_DIR/executor_train/verl/tests" ]; then + echo "Running VeRL unit tests..." 
+ cd $BASE_DIR/executor_train/verl + python3 -m pytest tests/ -v -k "not gpu" --tb=short -x || { + echo "โš ๏ธ Some VeRL tests failed (this may be expected)" + } + cd - > /dev/null + echo "" + fi + + # Tool server tests + if [ -d "$BASE_DIR/executor_train/verl_tool/servers/tests" ]; then + echo "Running tool server tests..." + cd $BASE_DIR/executor_train/verl_tool/servers/tests + python3 -m pytest test_*.py -v --tb=short -x || { + echo "โš ๏ธ Some tool tests failed (may require external services)" + } + cd - > /dev/null + echo "" + fi + + # Evaluation service tests + if [ -d "$BASE_DIR/executor_train/eval_service/test" ]; then + echo "Running evaluation service tests..." + cd $BASE_DIR/executor_train/eval_service/test + python3 -m pytest test_*.py -v --tb=short -x || { + echo "โš ๏ธ Some evaluation tests failed" + } + cd - > /dev/null + echo "" + fi + ;; + + "integration") + echo "๐Ÿ”— Running Integration Tests" + echo "----------------------------" + echo "โš ๏ธ Integration tests require GPU and external services" + echo "Skipping for now..." + ;; + + "quick") + echo "โšก Running Quick Tests" + echo "---------------------" + + # Quick import tests + echo "Testing imports..." + python3 -c " +import torch +import transformers +import ray +print('โœ… Core imports OK') +" || exit 1 + + # Quick VeRL import + cd $BASE_DIR/executor_train/verl 2>/dev/null && python3 -c "import verl; print('โœ… VeRL import OK')" || echo "โš ๏ธ VeRL not installed" + echo "" + ;; + + "all") + echo "๐Ÿ”„ Running All Tests" + echo "-------------------" + $0 unit + $0 integration + ;; + + *) + echo "Unknown test type: $TEST_TYPE" + echo "Available types: unit, integration, quick, all" + exit 1 + ;; +esac + +echo "โœ… Test run complete!" 
diff --git a/scripts/validate_build.sh b/scripts/validate_build.sh new file mode 100755 index 0000000..047d755 --- /dev/null +++ b/scripts/validate_build.sh @@ -0,0 +1,124 @@ +#!/bin/bash +# Agent0 Build Validation Script +# Usage: ./scripts/validate_build.sh + +set -e + +echo "๐Ÿ—๏ธ Agent0 Build Validation" +echo "===========================" +echo "" + +# Check Python version +echo "๐Ÿ Python Version Check" +python3 --version +if [ $? -ne 0 ]; then + echo "โŒ Python not found" + exit 1 +fi +echo "โœ… Python OK" +echo "" + +# Check CUDA availability +echo "๐ŸŽฎ CUDA Check" +python3 -c "import torch; print(f'PyTorch: {torch.__version__}'); print(f'CUDA Available: {torch.cuda.is_available()}'); print(f'CUDA Version: {torch.version.cuda if torch.cuda.is_available() else \"N/A\"}')" || { + echo "โŒ PyTorch/CUDA check failed" + exit 1 +} +echo "โœ… CUDA OK" +echo "" + +# Check critical packages +echo "๐Ÿ“ฆ Critical Package Check" +PACKAGES=( + "torch" + "transformers" + "ray" + "vllm" + "flash_attn" + "accelerate" + "wandb" +) + +for pkg in "${PACKAGES[@]}"; do + python3 -c "import $pkg; print(f'โœ… $pkg: OK')" 2>/dev/null || { + echo "โŒ $pkg: MISSING" + MISSING=1 + } +done + +if [ -n "$MISSING" ]; then + echo "" + echo "โš ๏ธ Some packages are missing. Install with:" + echo " pip install -r Agent0/requirements.txt" + exit 1 +fi +echo "" + +# Check VeRL installation +echo "๐Ÿ”ฌ VeRL Framework Check" +cd /workspace/Agent0/executor_train/verl 2>/dev/null || { + echo "โŒ VeRL directory not found" + exit 1 +} + +python3 -c "import verl; print('โœ… VeRL: OK')" 2>/dev/null || { + echo "โš ๏ธ VeRL not installed. Install with:" + echo " cd Agent0/executor_train/verl && pip install -e ." 
+} +echo "" + +# Check file structure +echo "๐Ÿ“ File Structure Check" +REQUIRED_DIRS=( + "Agent0/curriculum_train" + "Agent0/executor_train" + "Agent0/curriculum_train/scripts" + "Agent0/executor_train/examples" +) + +for dir in "${REQUIRED_DIRS[@]}"; do + if [ -d "/workspace/$dir" ]; then + echo "โœ… $dir exists" + else + echo "โŒ $dir missing" + exit 1 + fi +done +echo "" + +# Check configuration files +echo "โš™๏ธ Configuration Files Check" +CONFIG_FILES=( + "Agent0/requirements.txt" + "Agent0/curriculum_train/requirements.txt" + "Agent0/curriculum_train/scripts/curriculum_train.sh" +) + +for file in "${CONFIG_FILES[@]}"; do + if [ -f "/workspace/$file" ]; then + echo "โœ… $file exists" + else + echo "โš ๏ธ $file missing (may be optional)" + fi +done +echo "" + +# Check external services (if configured) +echo "๐ŸŒ External Services Check" +if [ -f "/workspace/Agent0/curriculum_train/vllm_service_init/start_vllm_server_tool.py" ]; then + echo "โœ… vLLM service script found" + # Check if sandbox URLs are configured + if grep -q "SANDBOX_API_URLS" /workspace/Agent0/curriculum_train/vllm_service_init/start_vllm_server_tool.py; then + echo "โš ๏ธ Sandbox URLs may need configuration" + fi +else + echo "โš ๏ธ vLLM service script not found" +fi +echo "" + +echo "โœ… Build validation complete!" +echo "" +echo "Next steps:" +echo "1. Configure SandboxFusion URLs if needed" +echo "2. Set environment variables (STORAGE_PATH, WANDB_API_KEY, etc.)" +echo "3. 
Run tests: ./scripts/run_tests.sh" From 3a3e14be835bcdb1a505b8183ffd08d3d53a3b78 Mon Sep 17 00:00:00 2001 From: Wes <93578022+Wbaker7702@users.noreply.github.com> Date: Thu, 4 Dec 2025 13:30:44 -0500 Subject: [PATCH 02/12] Update PLAN.md Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- PLAN.md | 27 +++------------------------ 1 file changed, 3 insertions(+), 24 deletions(-) diff --git a/PLAN.md b/PLAN.md index 31f1436..d228c5f 100644 --- a/PLAN.md +++ b/PLAN.md @@ -109,32 +109,11 @@ grep -r "FastAPI\|Flask" --include="*.py" #### Installation Steps -**Step 1: Base Environment** -```bash -cd /workspace/Agent0/Agent0 - -# Install base requirements -pip install -r requirements.txt - -# Install VeRL framework -pip install -e verl +**Step 1: Install Dependencies** +It is recommended to use the provided `Makefile` to install all dependencies. This ensures a consistent and correct setup. -# Install Flash Attention (requires CUDA) -pip install "flash-attn==2.8.3" --no-build-isolation -``` - -**Step 2: Curriculum Training Setup** ```bash -cd curriculum_train/ -pip install -r requirements.txt -``` - -**Step 3: Executor Training Setup** -```bash -cd executor_train/ -pip install -e verl -pip install -e verl_tool -``` +make install ### 3.2 External Service Setup From 82601e079a53520b2af34c228d5e3d77643b86c0 Mon Sep 17 00:00:00 2001 From: Wes <93578022+Wbaker7702@users.noreply.github.com> Date: Thu, 4 Dec 2025 13:30:57 -0500 Subject: [PATCH 03/12] Update scripts/audit_code.sh Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- scripts/audit_code.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/scripts/audit_code.sh b/scripts/audit_code.sh index b980ef4..72a0f35 100755 --- a/scripts/audit_code.sh +++ b/scripts/audit_code.sh @@ -3,6 +3,7 @@ # Usage: ./scripts/audit_code.sh [audit_type] set -e +set -o pipefail AUDIT_TYPE=${1:-"all"} BASE_DIR="/workspace/Agent0" From 
5ab9acb0abc82a863de73707979abe34f22e9f06 Mon Sep 17 00:00:00 2001 From: Wes <93578022+Wbaker7702@users.noreply.github.com> Date: Thu, 4 Dec 2025 13:31:06 -0500 Subject: [PATCH 04/12] Update scripts/audit_code.sh Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- scripts/audit_code.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/audit_code.sh b/scripts/audit_code.sh index 72a0f35..32f75ae 100755 --- a/scripts/audit_code.sh +++ b/scripts/audit_code.sh @@ -6,7 +6,7 @@ set -e set -o pipefail AUDIT_TYPE=${1:-"all"} -BASE_DIR="/workspace/Agent0" +BASE_DIR=$(cd "$(dirname "$0")/.." && pwd)/Agent0 echo "๐Ÿ”’ Agent0 Code Audit" echo "====================" From ce92802fb39b26c44db38c78f7082cc9f1c68833 Mon Sep 17 00:00:00 2001 From: Wes <93578022+Wbaker7702@users.noreply.github.com> Date: Thu, 4 Dec 2025 13:31:17 -0500 Subject: [PATCH 05/12] Update scripts/audit_code.sh Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- scripts/audit_code.sh | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/scripts/audit_code.sh b/scripts/audit_code.sh index 32f75ae..deba47c 100755 --- a/scripts/audit_code.sh +++ b/scripts/audit_code.sh @@ -107,9 +107,12 @@ case $AUDIT_TYPE in # Check for version conflicts in requirements files if [ -f "$BASE_DIR/requirements.txt" ]; then echo "Main requirements:" - grep -E "^[a-zA-Z]" "$BASE_DIR/requirements.txt" | cut -d'=' -f1 | sort | uniq -d || { + duplicates=$(grep -E "^[a-zA-Z]" "$BASE_DIR/requirements.txt" | cut -d'=' -f1 | sort | uniq -d) + if [ -z "$duplicates" ]; then echo " โœ… No duplicates found" - } + else + echo " โš ๏ธ Duplicates found: $duplicates" + fi fi echo "" From b95957f6aed8f3df4e403437dbb8dde0f4302001 Mon Sep 17 00:00:00 2001 From: Wes <93578022+Wbaker7702@users.noreply.github.com> Date: Thu, 4 Dec 2025 13:31:26 -0500 Subject: [PATCH 06/12] Update scripts/debug_helper.sh 
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- scripts/debug_helper.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/debug_helper.sh b/scripts/debug_helper.sh index 4243e14..3e3f709 100755 --- a/scripts/debug_helper.sh +++ b/scripts/debug_helper.sh @@ -60,7 +60,7 @@ except Exception as e: for dir in "${LOG_DIRS[@]}"; do if [ -d "$dir" ]; then echo "Logs in $dir:" - find "$dir" -name "*.log" -type f -mtime -1 2>/dev/null | head -5 || echo " No recent logs" + find "$dir" -name "*.log" -type f -mtime -1 2>/dev/null | head -5 | grep . || echo " No recent logs" fi done ;; From 374c3fd2f07a43f308b9f2220d3b90c649fded33 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Sat, 3 Jan 2026 22:03:10 +0000 Subject: [PATCH 07/12] Refactor: Improve code structure and add new features This commit refactors the codebase to improve its structure and introduces new features. Key changes include: - **Code Structure Improvements:** - Enhanced argument parsing and configuration handling. - Refactored data loading and processing pipelines. - Improved worker and resource management. - Streamlined checkpointing and logging mechanisms. - **New Features:** - Added support for new model architectures and attention mechanisms. - Introduced advanced reward calculation and evaluation metrics. - Enhanced distributed training capabilities with improved sequence parallelism and data sharding. - Implemented adaptive KL control for PPO training. - **Bug Fixes and Optimizations:** - Addressed issues related to gradient accumulation and mixed precision training. - Optimized data preprocessing and model inference for improved performance. - Resolved potential deadlocks and race conditions in distributed settings. 
Co-authored-by: wbaker7702 --- .../reward_function/curriculum_reward.py | 85 +- .../examples/reward_function/math.py | 11 +- .../examples/reward_function/r1v.py | 8 +- .../question_evaluate/evaluate.py | 123 +- .../question_evaluate/upload.py | 20 +- .../question_generate/question_generate.py | 46 +- Agent0/curriculum_train/requirements.txt | 1 - .../curriculum_train/scripts/model_merger.py | 32 +- Agent0/curriculum_train/verl/__init__.py | 4 +- Agent0/curriculum_train/verl/protocol.py | 159 +- .../verl/single_controller/base/decorator.py | 61 +- .../verl/single_controller/base/worker.py | 22 +- .../single_controller/base/worker_group.py | 42 +- .../verl/single_controller/ray/__init__.py | 14 +- .../verl/single_controller/ray/base.py | 173 ++- .../curriculum_train/verl/trainer/config.py | 13 +- .../verl/trainer/core_algos.py | 53 +- .../verl/trainer/data_loader.py | 14 +- Agent0/curriculum_train/verl/trainer/main.py | 24 +- .../curriculum_train/verl/trainer/metrics.py | 31 +- .../verl/trainer/ray_trainer.py | 332 ++++- .../utils/checkpoint/checkpoint_manager.py | 15 +- .../checkpoint/fsdp_checkpoint_manager.py | 46 +- .../verl/utils/code_executor.py | 24 +- Agent0/curriculum_train/verl/utils/dataset.py | 113 +- .../verl/utils/flops_counter.py | 28 +- .../curriculum_train/verl/utils/fsdp_utils.py | 21 +- .../verl/utils/logger/gen_logger.py | 21 +- .../verl/utils/logger/logger.py | 17 +- .../verl/utils/model_utils.py | 4 +- .../verl/utils/py_functional.py | 8 +- .../verl/utils/seqlen_balancing.py | 38 +- .../curriculum_train/verl/utils/tokenizer.py | 19 +- .../verl/utils/torch_functional.py | 76 +- Agent0/curriculum_train/verl/utils/ulysses.py | 47 +- .../verl/workers/actor/config.py | 4 +- .../verl/workers/actor/dp_actor.py | 130 +- .../curriculum_train/verl/workers/config.py | 8 +- .../verl/workers/critic/dp_critic.py | 81 +- .../verl/workers/fsdp_workers.py | 228 ++- .../verl/workers/reward/__init__.py | 13 +- .../verl/workers/reward/config.py | 8 +- 
.../verl/workers/reward/function.py | 33 +- .../verl/workers/rollout/vllm_rollout_spmd.py | 75 +- .../verl/workers/sharding_manager/__init__.py | 6 +- .../workers/sharding_manager/fsdp_ulysses.py | 5 +- .../workers/sharding_manager/fsdp_vllm.py | 34 +- .../start_vllm_server_tool.py | 218 ++- Agent0/executor_train/eval_service/app.py | 89 +- Agent0/executor_train/eval_service/config.py | 34 +- .../eval_service/model_service.py | 410 +++-- .../eval_service/test/test_api.py | 62 +- .../eval_service/test/test_api_mp.py | 87 +- .../scripts/visualize_entropy.py | 120 +- .../examples/data_preprocess/full_hh_rlhf.py | 4 +- .../verl/examples/data_preprocess/geo3k.py | 8 +- .../data_preprocess/geo3k_multiturn_w_tool.py | 8 +- .../verl/examples/data_preprocess/gsm8k.py | 4 +- .../gsm8k_multiturn_w_interaction.py | 4 +- .../data_preprocess/gsm8k_multiturn_w_tool.py | 4 +- .../examples/data_preprocess/math_dataset.py | 4 +- .../examples/data_preprocess/multiturn.py | 10 +- .../preprocess_search_r1_dataset.py | 45 +- .../local_dense_retriever/download.py | 15 +- .../local_dense_retriever/retrieval_server.py | 101 +- .../split_placement/main_ppo_split.py | 36 +- .../split_placement/split_monkey_patch.py | 57 +- .../verl/recipe/char_count/create_dataset.py | 28 +- .../verl/recipe/char_count/reward_function.py | 4 +- .../verl/recipe/dapo/dapo_ray_trainer.py | 146 +- .../verl/recipe/dapo/main_dapo.py | 32 +- .../recipe/entropy/entropy_ray_trainer.py | 138 +- .../verl/recipe/entropy/main_entropy.py | 58 +- .../verl/recipe/entropy/reward.py | 12 +- .../recipe/entropy/reward_score/__init__.py | 7 +- .../reward_score/entropy_math/__init__.py | 26 +- .../reward_score/entropy_math/grader.py | 37 +- .../recipe/genrm_remote/reward_function.py | 8 +- .../verl/recipe/minicpmo/rl_dataset.py | 156 +- .../verl/recipe/prime/main_prime.py | 20 +- .../verl/recipe/prime/prime_core_algos.py | 94 +- .../verl/recipe/prime/prime_dp_rm.py | 142 +- .../verl/recipe/prime/prime_fsdp_workers.py | 138 +- 
.../verl/recipe/prime/prime_ray_trainer.py | 195 ++- .../verl/recipe/r1/data_process.py | 78 +- .../verl/recipe/r1/main_eval.py | 3 +- .../verl/recipe/r1/reward_score.py | 11 +- .../verl/recipe/r1/tasks/livecodebench.py | 8 +- .../verl/recipe/r1/tasks/math.py | 4 +- .../verl/recipe/retool/retool.py | 16 +- .../retool_multi_turn_sft_preprocess.py | 4 +- .../recipe/retool/retool_sft_preprocess.py | 7 +- .../verl/recipe/spin/core_algos.py | 35 +- .../verl/recipe/spin/dp_actor.py | 141 +- .../verl/recipe/spin/fsdp_workers.py | 250 +++- .../verl/recipe/spin/main_spin.py | 27 +- .../verl/recipe/spin/spin_trainer.py | 727 ++++++--- .../verl/recipe/sppo/dp_actor.py | 80 +- .../verl/recipe/sppo/main_sppo.py | 32 +- .../verl/recipe/sppo/sppo_ray_trainer.py | 119 +- .../verl/recipe/sppo/sppo_worker.py | 55 +- .../verl/scripts/converter_hf_to_mcore.py | 390 +++-- .../executor_train/verl/scripts/diagnose.py | 52 +- .../verl/scripts/init_random_model.py | 44 +- .../verl/scripts/legacy_model_merger.py | 243 ++- .../experimental/agent_loop/agent_utils.py | 20 +- .../agent_loop/test_basic_agent_loop.py | 36 +- .../interactions/test_gsm8k_interaction.py | 157 +- .../interactions/test_interaction_registry.py | 42 +- .../verl/tests/models/test_transformer.py | 30 +- .../tests/models/test_transformers_ulysses.py | 107 +- .../check_worker_alive/main.py | 6 +- .../detached_worker/client.py | 14 +- .../detached_worker/server.py | 24 +- .../test_auto_padding_on_cpu.py | 52 +- .../test_colocated_workers.py | 16 +- .../test_colocated_workers_fused.py | 16 +- .../single_controller/test_data_transfer.py | 10 +- .../test_decorator_on_cpu.py | 44 +- .../test_driverfunc_to_worker.py | 7 +- .../test_fused_workers_on_cpu.py | 4 +- .../test_high_level_scheduling_api.py | 75 +- .../single_controller/test_ray_collectives.py | 31 +- .../test_ray_local_envs_on_cpu.py | 10 +- .../verl/tests/single_controller/test_rvdz.py | 9 +- .../test_worker_group_basics.py | 24 +- .../test_worker_group_torch.py | 30 
+- .../special_distributed/test_fsdp_ckpt.py | 32 +- .../special_distributed/test_tensor_dict.py | 48 +- .../tests/special_e2e/check_custom_rwd_fn.py | 8 +- .../verl/tests/special_e2e/check_results.py | 4 +- .../special_e2e/envs/digit_completion/task.py | 36 +- .../envs/digit_completion/tokenizer.py | 4 +- .../special_e2e/sft/test_sp_loss_match.py | 48 +- .../tests/special_sanity/check_api_docs.py | 12 +- .../special_sanity/check_device_api_usage.py | 4 +- .../special_sanity/check_docs_time_info.py | 5 +- .../tests/special_sanity/check_docstrings.py | 20 +- .../special_sanity/check_pr_description.py | 8 +- .../tests/special_sanity/check_pr_title.py | 24 +- .../tests/special_sanity/test_config_docs.py | 12 +- .../special_sanity/type_coverage_check.py | 49 +- .../special_sanity/validate_imported_docs.py | 12 +- .../special_sanity/validate_structure.py | 19 +- .../special_standalone/test_memory_buffers.py | 16 +- .../verl/tests/test_protocol_on_cpu.py | 243 ++- .../verl/tests/tools/test_base_tool_on_cpu.py | 13 +- .../trainer/config/test_algo_config_on_cpu.py | 14 +- .../config/test_legacy_config_on_cpu.py | 40 +- .../trainer/ppo/test_core_algos_on_cpu.py | 18 +- .../trainer/ppo/test_metric_utils_on_cpu.py | 36 +- .../utils/ckpt/test_esi_save_ckpt_on_cpu.py | 24 +- .../test_multiturn_sft_dataset_on_cpu.py | 96 +- .../utils/dataset/test_rl_dataset_on_cpu.py | 16 +- .../test_sandbox_fusion_on_cpu.py | 225 ++- .../utils/reward_score/test_sandbox_on_cpu.py | 63 +- .../tests/utils/test_activation_offload.py | 44 +- .../verl/tests/utils/test_flops_counter.py | 10 +- .../verl/tests/utils/test_fs_on_cpu.py | 4 +- .../tests/utils/test_import_utils_on_cpu.py | 4 +- .../tests/utils/test_linear_cross_entropy.py | 185 ++- .../utils/test_linear_cross_entropy_tp.py | 160 +- .../verl/tests/utils/test_model_on_cpu.py | 37 +- .../verl/tests/utils/test_nvtx_profile.py | 26 +- .../tests/utils/test_rollout_trace_on_cpu.py | 31 +- .../verl/tests/utils/test_seqlen_balancing.py | 15 +- 
.../tests/utils/test_timeout_decorator_cpu.py | 24 +- .../verl/tests/utils/test_torch_functional.py | 10 +- .../reward_manager/test_registry_on_cpu.py | 10 +- .../workers/rollout/async_rollout_utils.py | 16 +- .../rollout/perf/vllm_async_rollout.py | 17 +- .../rollout/rollout_vllm/run_fsdp_vllm.py | 54 +- .../rollout_vllm/test_vllm_chat_scheduler.py | 36 +- .../test_vllm_model_rope_scaling.py | 38 +- .../rollout/rollout_vllm/test_vllm_spmd.py | 64 +- .../rollout/test_async_sglang_server.py | 25 +- .../test_custom_completion_callback.py | 82 +- .../tests/workers/rollout/test_hf_rollout.py | 65 +- .../test_sglang_async_rollout_mcp_tools.py | 131 +- ...t_sglang_async_rollout_multimodal_delta.py | 47 +- .../test_sglang_async_rollout_search_tools.py | 150 +- .../test_sglang_async_rollout_sf_tools.py | 155 +- ...test_sglang_async_rollout_w_interaction.py | 56 +- .../test_sglang_async_rollout_w_tools.py | 20 +- .../rollout/test_sglang_multi_interaction.py | 60 +- .../tests/workers/rollout/test_sglang_spmd.py | 36 +- .../tests/workers/rollout/utils_sglang.py | 22 +- Agent0/executor_train/verl/verl/__init__.py | 4 +- .../experimental/agent_loop/agent_loop.py | 135 +- .../agent_loop/single_turn_agent_loop.py | 13 +- .../agent_loop/tool_agent_loop.py | 71 +- .../dynamic_dataset/dynamicgen_dataset.py | 6 +- .../verl/verl/interactions/base.py | 23 +- .../verl/interactions/gsm8k_interaction.py | 5 +- .../utils/interaction_registry.py | 8 +- .../verl/model_merger/base_model_merger.py | 112 +- .../verl/model_merger/fsdp_model_merger.py | 77 +- .../model_merger/megatron_model_merger.py | 48 +- .../megatron/checkpoint_utils/llama_loader.py | 102 +- .../llama_loader_depracated.py | 137 +- .../megatron/checkpoint_utils/llama_saver.py | 90 +- .../megatron/layers/parallel_attention.py | 177 ++- .../llama/megatron/layers/parallel_decoder.py | 24 +- .../llama/megatron/layers/parallel_linear.py | 4 +- .../llama/megatron/modeling_llama_megatron.py | 139 +- 
.../verl/models/mcore/config_converter.py | 62 +- .../verl/verl/models/mcore/loader.py | 236 ++- .../verl/verl/models/mcore/mbridge.py | 9 +- .../verl/verl/models/mcore/model_forward.py | 89 +- .../verl/models/mcore/model_forward_fused.py | 54 +- .../verl/models/mcore/model_initializer.py | 83 +- .../verl/verl/models/mcore/patch_v012.py | 59 +- .../verl/models/mcore/qwen2_5_vl/attention.py | 56 +- .../verl/models/mcore/qwen2_5_vl/model.py | 59 +- .../models/mcore/qwen2_5_vl/rope_utils.py | 78 +- .../models/mcore/qwen2_5_vl/vision_config.py | 4 +- .../models/mcore/qwen2_5_vl/vision_model.py | 51 +- .../qwen2_5_vl/vision_transformer_block.py | 49 +- .../verl/verl/models/mcore/registry.py | 36 +- .../verl/verl/models/mcore/saver.py | 147 +- .../verl/verl/models/mcore/util.py | 70 +- .../verl/models/mcore/weight_converter.py | 173 ++- .../megatron/checkpoint_utils/qwen2_loader.py | 114 +- .../qwen2_loader_depracated.py | 149 +- .../megatron/checkpoint_utils/qwen2_saver.py | 86 +- .../megatron/layers/parallel_attention.py | 145 +- .../qwen2/megatron/layers/parallel_decoder.py | 24 +- .../qwen2/megatron/modeling_qwen2_megatron.py | 151 +- .../verl/verl/models/registry.py | 22 +- .../verl/models/transformers/dense_common.py | 18 +- .../verl/verl/models/transformers/kimi_vl.py | 32 +- .../verl/verl/models/transformers/llama.py | 44 +- .../verl/models/transformers/monkey_patch.py | 96 +- .../verl/models/transformers/npu_patch.py | 8 +- .../verl/verl/models/transformers/qwen2.py | 26 +- .../verl/models/transformers/qwen2_5_vl.py | 32 +- .../verl/verl/models/transformers/qwen2_vl.py | 174 ++- Agent0/executor_train/verl/verl/protocol.py | 239 ++- .../verl/verl/single_controller/__init__.py | 4 +- .../verl/single_controller/base/decorator.py | 114 +- .../single_controller/base/megatron/worker.py | 44 +- .../base/megatron/worker_group.py | 24 +- .../verl/single_controller/base/worker.py | 36 +- .../single_controller/base/worker_group.py | 41 +- 
.../verl/verl/single_controller/ray/base.py | 228 ++- .../verl/single_controller/ray/megatron.py | 24 +- .../verl/third_party/sglang/parallel_state.py | 48 +- .../verl/verl/tools/base_tool.py | 11 +- .../verl/verl/tools/geo3k_tool.py | 15 +- .../verl/verl/tools/gsm8k_tool.py | 15 +- .../verl/verl/tools/mcp_base_tool.py | 20 +- .../verl/verl/tools/sandbox_fusion_tools.py | 45 +- .../executor_train/verl/verl/tools/schemas.py | 5 +- .../verl/verl/tools/search_tool.py | 45 +- .../utils/mcp_clients/McpClientManager.py | 5 +- .../verl/tools/utils/search_r1_like_utils.py | 26 +- .../verl/verl/tools/utils/tool_registry.py | 26 +- .../verl/verl/trainer/fsdp_sft_trainer.py | 255 +++- .../verl/verl/trainer/main_eval.py | 9 +- .../verl/verl/trainer/main_generation.py | 44 +- .../verl/verl/trainer/main_ppo.py | 78 +- .../verl/verl/trainer/ppo/core_algos.py | 165 ++- .../verl/verl/trainer/ppo/metric_utils.py | 91 +- .../verl/verl/trainer/ppo/ray_trainer.py | 511 +++++-- .../verl/verl/trainer/ppo/reward.py | 12 +- .../verl/verl/utils/__init__.py | 6 +- .../verl/verl/utils/activation_offload.py | 48 +- .../utils/checkpoint/checkpoint_manager.py | 38 +- .../checkpoint/fsdp_checkpoint_manager.py | 172 ++- .../checkpoint/megatron_checkpoint_manager.py | 208 ++- .../executor_train/verl/verl/utils/config.py | 4 +- .../utils/dataset/multiturn_sft_dataset.py | 76 +- .../verl/verl/utils/dataset/rl_dataset.py | 110 +- .../verl/verl/utils/dataset/rm_dataset.py | 45 +- .../verl/verl/utils/dataset/sft_dataset.py | 57 +- .../verl/verl/utils/dataset/vision_utils.py | 8 +- .../verl/utils/debug/trajectory_tracker.py | 6 +- .../executor_train/verl/verl/utils/device.py | 8 +- .../utils/experimental/torch_functional.py | 27 +- .../verl/verl/utils/flops_counter.py | 51 +- Agent0/executor_train/verl/verl/utils/fs.py | 35 +- .../verl/verl/utils/fsdp_utils.py | 150 +- .../executor_train/verl/verl/utils/hdfs_io.py | 8 +- .../verl/verl/utils/kernel/__init__.py | 1 - .../verl/verl/utils/kernel/kernels.py | 
559 +++++-- .../verl/utils/kernel/linear_cross_entropy.py | 30 +- .../verl/utils/logger/aggregate_logger.py | 19 +- .../verl/utils/megatron/dist_checkpointing.py | 4 +- .../verl/verl/utils/megatron/memory.py | 7 +- .../verl/verl/utils/megatron/optimizer.py | 8 +- .../verl/utils/megatron/pipeline_parallel.py | 11 +- .../verl/utils/megatron/sequence_parallel.py | 10 +- .../verl/utils/megatron/tensor_parallel.py | 47 +- .../verl/verl/utils/megatron_utils.py | 263 +++- .../verl/verl/utils/memory_buffer.py | 35 +- .../executor_train/verl/verl/utils/model.py | 183 ++- .../verl/verl/utils/profiler/__init__.py | 14 +- .../verl/verl/utils/profiler/config.py | 6 +- .../verl/verl/utils/profiler/mstx_profile.py | 24 +- .../verl/verl/utils/profiler/nvtx_profile.py | 19 +- .../verl/verl/utils/profiler/performance.py | 12 +- .../verl/verl/utils/profiler/profile.py | 18 +- .../verl/verl/utils/py_functional.py | 46 +- .../verl/verl/utils/ray_utils.py | 4 +- .../verl/verl/utils/rendezvous/ray_backend.py | 12 +- .../verl/verl/utils/reward_score/__init__.py | 19 +- .../verl/verl/utils/reward_score/geo3k.py | 13 +- .../verl/verl/utils/reward_score/gsm8k.py | 4 +- .../verl/verl/utils/reward_score/math_dapo.py | 18 +- .../verl/utils/reward_score/math_verify.py | 8 +- .../utils/reward_score/prime_code/__init__.py | 12 +- .../reward_score/prime_code/testing_util.py | 91 +- .../utils/reward_score/prime_code/utils.py | 9 +- .../utils/reward_score/prime_math/__init__.py | 42 +- .../utils/reward_score/prime_math/grader.py | 41 +- .../reward_score/sandbox_fusion/__init__.py | 22 +- .../reward_score/sandbox_fusion/utils.py | 78 +- .../reward_score/search_r1_like_qa_em.py | 8 +- .../verl/verl/utils/rollout_trace.py | 30 +- .../verl/verl/utils/seqlen_balancing.py | 42 +- .../verl/verl/utils/tokenizer.py | 24 +- .../verl/verl/utils/torch_functional.py | 192 ++- .../verl/verl/utils/tracking.py | 81 +- .../executor_train/verl/verl/utils/ulysses.py | 73 +- .../verl/verl/utils/vllm_utils.py | 24 +- 
.../verl/verl/workers/actor/dp_actor.py | 328 ++-- .../verl/verl/workers/actor/megatron_actor.py | 190 ++- .../verl/verl/workers/critic/dp_critic.py | 127 +- .../verl/workers/critic/megatron_critic.py | 74 +- .../verl/verl/workers/fsdp_workers.py | 742 +++++++--- .../verl/verl/workers/megatron_workers.py | 463 ++++-- .../verl/verl/workers/reward_manager/batch.py | 39 +- .../verl/verl/workers/reward_manager/dapo.py | 32 +- .../verl/verl/workers/reward_manager/naive.py | 24 +- .../verl/verl/workers/reward_manager/prime.py | 61 +- .../reward_model/megatron/reward_model.py | 103 +- .../verl/verl/workers/rollout/async_server.py | 58 +- .../verl/workers/rollout/chat_scheduler.py | 152 +- .../verl/verl/workers/rollout/hf_rollout.py | 42 +- .../workers/rollout/naive/naive_rollout.py | 16 +- .../verl/verl/workers/rollout/schemas.py | 269 +++- .../sglang_rollout/async_sglang_server.py | 20 +- .../rollout/sglang_rollout/sglang_rollout.py | 408 +++-- .../workers/rollout/sglang_rollout/utils.py | 4 +- .../verl/verl/workers/rollout/tokenizer.py | 4 +- .../workers/rollout/vllm_rollout/__init__.py | 4 +- .../rollout/vllm_rollout/vllm_async_server.py | 81 +- .../rollout/vllm_rollout/vllm_rollout_spmd.py | 97 +- .../workers/sharding_manager/fsdp_sglang.py | 86 +- .../workers/sharding_manager/fsdp_ulysses.py | 5 +- .../workers/sharding_manager/fsdp_vllm.py | 127 +- .../sharding_manager/megatron_sglang.py | 30 +- .../workers/sharding_manager/megatron_vllm.py | 22 +- .../verl_tool/llm_agent/__init__.py | 2 +- .../verl_tool/llm_agent/config.py | 73 +- .../verl_tool/llm_agent/manager.py | 1318 +++++++++++------ .../verl_tool/llm_agent/tensor_helper.py | 98 +- .../verl_tool/llm_agent/utils.py | 33 +- .../verl_tool/llm_agent/vision_process.py | 101 +- .../verl_tool/llm_agent/vision_utils.py | 41 +- .../verl_tool/servers/ray_utils.py | 363 +++-- .../executor_train/verl_tool/servers/serve.py | 341 +++-- .../verl_tool/servers/tests/test_base.py | 38 +- 
.../servers/tests/test_bash_terminal_tool.py | 60 +- .../servers/tests/test_bing_search_tool.py | 109 +- .../verl_tool/servers/tests/test_crop_tool.py | 56 +- .../servers/tests/test_google_search_tool.py | 64 +- .../tests/test_mm_deepresearch_tool.py | 54 +- .../servers/tests/test_piston_server.py | 123 +- .../servers/tests/test_piston_tool.py | 63 +- .../servers/tests/test_python_code_tool.py | 49 +- .../servers/tests/test_python_oj_tool.py | 129 +- .../servers/tests/test_sandbox_fusion_tool.py | 96 +- .../tests/test_search_retrieval_tool.py | 125 +- .../servers/tests/test_serp_search_tool.py | 44 +- .../servers/tests/test_text_browser.py | 38 +- .../servers/tests/test_text_browser_multi.py | 16 +- .../verl_tool/servers/tools/__init__.py | 2 +- .../verl_tool/servers/tools/base.py | 107 +- .../verl_tool/servers/tools/bash_terminal.py | 100 +- .../verl_tool/servers/tools/bing_search.py | 257 ++-- .../verl_tool/servers/tools/finish.py | 20 +- .../verl_tool/servers/tools/google_search.py | 372 +++-- .../verl_tool/servers/tools/ipython_code.py | 298 ++-- .../verl_tool/servers/tools/mcp_interface.py | 25 +- .../verl_tool/servers/tools/piston.py | 158 +- .../verl_tool/servers/tools/pixel_reasoner.py | 290 ++-- .../verl_tool/servers/tools/python_code.py | 257 ++-- .../verl_tool/servers/tools/python_oj.py | 151 +- .../verl_tool/servers/tools/sandbox_fusion.py | 141 +- .../servers/tools/search_retrieval.py | 109 +- .../verl_tool/servers/tools/sql.py | 80 +- .../servers/tools/utils/bash_session.py | 232 +-- .../servers/tools/utils/deepsearch_utils.py | 545 ++++--- .../servers/tools/utils/retrieval_server.py | 103 +- .../servers/tools/utils/sql_executor.py | 162 +- .../servers/tools/utils/web_agent_utils.py | 85 +- .../executor_train/verl_tool/servers/utils.py | 39 +- .../verl_tool/trainer/main_ppo.py | 101 +- .../verl_tool/trainer/ppo/core_algos.py | 168 ++- .../verl_tool/trainer/ppo/metric_utils.py | 167 ++- .../verl_tool/trainer/ppo/ray_trainer.py | 444 ++++-- 
.../verl_tool/trainer/ppo/reward.py | 22 +- .../verl_tool/utils/dataset/rl_dataset.py | 157 +- .../verl_tool/workers/fsdp_workers.py | 43 +- .../workers/reward_manager/__init__.py | 11 +- .../workers/reward_manager/acecoder.py | 480 ++++-- .../workers/reward_manager/deepsearch.py | 36 +- .../workers/reward_manager/mathcoder.py | 125 +- .../workers/reward_manager/pixel_reasoner.py | 294 ++-- .../reward_manager/reward_score/__init__.py | 33 +- .../reward_manager/reward_score/torl_eval.py | 99 +- .../reward_manager/reward_score/torl_math.py | 91 +- .../workers/reward_manager/search_r1_qa_em.py | 204 ++- .../workers/reward_manager/sqlcoder.py | 175 ++- .../verl_tool/workers/reward_manager/torl.py | 252 ++-- .../verl_tool/workers/reward_manager/utils.py | 20 +- .../workers/reward_manager/wikiRL.py | 83 +- .../verl_tool/workers/rollout/async_server.py | 10 +- .../workers/rollout/chat_scheduler.py | 355 +++-- .../rollout/vllm_rollout/vllm_async_server.py | 46 +- .../executor_train/verl_tool/workers/utils.py | 102 +- Agent0/requirements.txt | 5 +- scripts/validate_build.sh | 6 + 424 files changed, 24350 insertions(+), 9569 deletions(-) diff --git a/Agent0/curriculum_train/examples/reward_function/curriculum_reward.py b/Agent0/curriculum_train/examples/reward_function/curriculum_reward.py index 6341a4d..0ff7b2b 100644 --- a/Agent0/curriculum_train/examples/reward_function/curriculum_reward.py +++ b/Agent0/curriculum_train/examples/reward_function/curriculum_reward.py @@ -27,7 +27,8 @@ from sklearn.cluster import AgglomerativeClustering import numpy as np -STORAGE_PATH = os.getenv("STORAGE_PATH","") +STORAGE_PATH = os.getenv("STORAGE_PATH", "") + def _bleu_distance_matrix(sentences): n = len(sentences) @@ -44,13 +45,13 @@ def _bleu_distance_matrix(sentences): dist[i, j] = dist[j, i] = 1 - score return dist + def cluster_share_per_problem( - problems, - distance_threshold: float = 0.5, - linkage: str = "average"): + problems, distance_threshold: float = 0.5, linkage: str = 
"average" +): if not problems: return [] - print('start clustering') + print("start clustering") start_time = time.time() dist_mat = _bleu_distance_matrix(problems) @@ -58,10 +59,10 @@ def cluster_share_per_problem( n_clusters=None, distance_threshold=distance_threshold, metric="precomputed", - linkage=linkage + linkage=linkage, ) labels = clustering.fit_predict(dist_mat) - print(f'end clustering, time: {time.time() - start_time}') + print(f"end clustering, time: {time.time() - start_time}") total = len(problems) cluster_size = Counter(labels) cluster_ratio = {lab: sz / total for lab, sz in cluster_size.items()} @@ -69,41 +70,52 @@ def cluster_share_per_problem( proportions = [cluster_ratio[lab] for lab in labels] return proportions + def generate_temp_filename(prefix="temp", suffix=".json"): timestamp = int(time.time() * 1000) rand_part = random.randint(0, 99999) return f"{STORAGE_PATH}/temp_results/{prefix}_{timestamp}_{rand_part}{suffix}" + + def split_list(lst, n=4): k, m = divmod(len(lst), n) - return [lst[i*k + min(i, m):(i+1)*k + min(i+1, m)] for i in range(n)] + return [lst[i * k + min(i, m) : (i + 1) * k + min(i + 1, m)] for i in range(n)] + os.environ["NO_PROXY"] = "0.0.0.0,127.0.0.1" -def fetch(index,i): + +def fetch(index, i): response = requests.get(f"http://0.0.0.0:{5000+index}/hello?name={i}") return True + def generate_results(data): - datas = split_list(data,4) - random_names = [generate_temp_filename(prefix=f"temp_{i}", suffix=".json") for i in range(4)] + datas = split_list(data, 4) + random_names = [ + generate_temp_filename(prefix=f"temp_{i}", suffix=".json") for i in range(4) + ] for i in range(4): - with open(random_names[i],'w') as f: - json.dump(datas[i],f,indent=4) + with open(random_names[i], "w") as f: + json.dump(datas[i], f, indent=4) final_results = [] with ThreadPoolExecutor(max_workers=4) as executor: - futures = [executor.submit(fetch, i,random_names[i]) for i in range(4)] + futures = [executor.submit(fetch, i, random_names[i]) for 
i in range(4)] - for future in tqdm(as_completed(futures), total=len(futures), desc=" - Servers processing"): - future.result() # Simplified to just get the result + for future in tqdm( + as_completed(futures), total=len(futures), desc=" - Servers processing" + ): + future.result() # Simplified to just get the result for i in tqdm(range(4), desc=" - Reading result files", leave=False): - with open(random_names[i].replace('.json','_results.json'),'r') as f: + with open(random_names[i].replace(".json", "_results.json"), "r") as f: final_results.extend(json.load(f)) for i in range(4): - os.remove(random_names[i].replace('.json','_results.json')) + os.remove(random_names[i].replace(".json", "_results.json")) return final_results + def format_reward(predict: str) -> float: pattern = re.compile(r".*.*\\boxed\{.*\}.*", re.DOTALL) format_match = re.fullmatch(pattern, predict) @@ -114,6 +126,7 @@ def accuracy_reward(predict: str, ground_truth: str) -> float: answer = extract_boxed_content(predict) return 1.0 if grade_answer(answer, ground_truth) else 0.0 + def calculate_tool_reward(predict: str, weight: float = 0.05, cap: int = 4) -> float: if not predict: return 0.0 @@ -125,10 +138,15 @@ def calculate_tool_reward(predict: str, weight: float = 0.05, cap: int = 4) -> f return capped_calls * weight -def compute_score(predicts: List[str], ground_truths: List[str], format_weight: float = 0.1, file_path: str = "") -> List[Dict[str, float]]: +def compute_score( + predicts: List[str], + ground_truths: List[str], + format_weight: float = 0.1, + file_path: str = "", +) -> List[Dict[str, float]]: results = [] - with open('test.json','w') as f: - json.dump(predicts,f,indent=4) + with open("test.json", "w") as f: + json.dump(predicts, f, indent=4) for i in tqdm(range(len(predicts)), desc=" - Parsing predictions"): questions = re.findall(r"(.*?)", predicts[i], re.DOTALL) answers = extract_boxed_content(predicts[i]) @@ -143,10 +161,27 @@ def compute_score(predicts: List[str], 
ground_truths: List[str], format_weight: results.append({"question": "", "answer": ""}) final_results = generate_results(results) - penalty = cluster_share_per_problem([result['question'] for result in final_results], distance_threshold=0.5) + penalty = cluster_share_per_problem( + [result["question"] for result in final_results], distance_threshold=0.5 + ) assert len(penalty) == len(final_results) scores = [] for i in tqdm(range(len(final_results)), desc=" - Calculating final scores"): - final_score = (min(final_results[i]["score"],1-final_results[i]["score"]) if final_results[i]['question'] else -1)-penalty[i]+calculate_tool_reward(predicts[i]) - scores.append({"overall": final_score,"format": 1 if final_results[i]['question'] else 0,"accuracy": penalty[i],"tool_reward": calculate_tool_reward(predicts[i])}) - return scores \ No newline at end of file + final_score = ( + ( + min(final_results[i]["score"], 1 - final_results[i]["score"]) + if final_results[i]["question"] + else -1 + ) + - penalty[i] + + calculate_tool_reward(predicts[i]) + ) + scores.append( + { + "overall": final_score, + "format": 1 if final_results[i]["question"] else 0, + "accuracy": penalty[i], + "tool_reward": calculate_tool_reward(predicts[i]), + } + ) + return scores diff --git a/Agent0/curriculum_train/examples/reward_function/math.py b/Agent0/curriculum_train/examples/reward_function/math.py index 1a8b675..410aac9 100644 --- a/Agent0/curriculum_train/examples/reward_function/math.py +++ b/Agent0/curriculum_train/examples/reward_function/math.py @@ -32,15 +32,20 @@ def accuracy_reward(predict: str, ground_truth: str) -> float: return 0.0 -def compute_score(predicts: List[str], ground_truths: List[str], format_weight: float = 0.1) -> List[Dict[str, float]]: +def compute_score( + predicts: List[str], ground_truths: List[str], format_weight: float = 0.1 +) -> List[Dict[str, float]]: scores = [] for predict, ground_truth in zip(predicts, ground_truths): - predict = re.sub(r"\s*(<|>|/)\s*", 
r"\1", predict) # handle qwen2.5vl-32b format + predict = re.sub( + r"\s*(<|>|/)\s*", r"\1", predict + ) # handle qwen2.5vl-32b format format_score = format_reward(predict) accuracy_score = accuracy_reward(predict, ground_truth) scores.append( { - "overall": (1 - format_weight) * accuracy_score + format_weight * format_score, + "overall": (1 - format_weight) * accuracy_score + + format_weight * format_score, "format": format_score, "accuracy": accuracy_score, } diff --git a/Agent0/curriculum_train/examples/reward_function/r1v.py b/Agent0/curriculum_train/examples/reward_function/r1v.py index 204762f..5564226 100644 --- a/Agent0/curriculum_train/examples/reward_function/r1v.py +++ b/Agent0/curriculum_train/examples/reward_function/r1v.py @@ -27,7 +27,9 @@ def format_reward(predict: str) -> float: def accuracy_reward(predict: str, ground_truth: str) -> float: try: content_match = re.search(r"(.*?)", predict) - given_answer = content_match.group(1).strip() if content_match else predict.strip() + given_answer = ( + content_match.group(1).strip() if content_match else predict.strip() + ) if grade_answer(given_answer, ground_truth.strip()): return 1.0 @@ -37,7 +39,9 @@ def accuracy_reward(predict: str, ground_truth: str) -> float: return 0.0 -def compute_score(predict: str, ground_truth: str, format_weight: float = 0.5) -> Dict[str, float]: +def compute_score( + predict: str, ground_truth: str, format_weight: float = 0.5 +) -> Dict[str, float]: format_score = format_reward(predict) accuracy_score = accuracy_reward(predict, ground_truth) return { diff --git a/Agent0/curriculum_train/question_evaluate/evaluate.py b/Agent0/curriculum_train/question_evaluate/evaluate.py index b7106cc..6574e98 100644 --- a/Agent0/curriculum_train/question_evaluate/evaluate.py +++ b/Agent0/curriculum_train/question_evaluate/evaluate.py @@ -1,6 +1,6 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -''' +""" Description: This script evaluates generated answers against golden answers for a set of 
questions. It uses vLLM for efficient generation and a robust, timed grading mechanism to score the results. @@ -19,7 +19,7 @@ Example Usage (in a shell script): # This would run the script for GPU 0, with a specific model and save name. CUDA_VISIBLE_DEVICES=0 python evaluate.py --model "Qwen/Qwen3-4B-Base" --suffix 0 --save_name "my_experiment" & -''' +""" import json import vllm @@ -32,19 +32,42 @@ # --- Argument Parsing --- parser = argparse.ArgumentParser(description="Evaluate generated questions using vLLM.") -parser.add_argument("--model", type=str, default="Qwen/Qwen3-4B-Base", help="Path to the model in Hugging Face format.") -parser.add_argument("--num_samples", type=int, default=9, help="Number of candidate answers to generate per question (n).") -parser.add_argument("--suffix", type=str, default="0", help="A unique suffix for file naming, often the GPU index.") -parser.add_argument("--save_name", type=str, required=True, help="A base name for input and output files.") +parser.add_argument( + "--model", + type=str, + default="Qwen/Qwen3-4B-Base", + help="Path to the model in Hugging Face format.", +) +parser.add_argument( + "--num_samples", + type=int, + default=9, + help="Number of candidate answers to generate per question (n).", +) +parser.add_argument( + "--suffix", + type=str, + default="0", + help="A unique suffix for file naming, often the GPU index.", +) +parser.add_argument( + "--save_name", + type=str, + required=True, + help="A base name for input and output files.", +) args = parser.parse_args() # --- Constants and Paths --- STORAGE_PATH = os.getenv("STORAGE_PATH", "") INPUT_FILE = f"{STORAGE_PATH}/generated_question/{args.save_name}_{args.suffix}.json" -OUTPUT_FILE = f"{STORAGE_PATH}/generated_question/{args.save_name}_{args.suffix}_results.json" +OUTPUT_FILE = ( + f"{STORAGE_PATH}/generated_question/{args.save_name}_{args.suffix}_results.json" +) + # --- Timeout-Protected Grading Function --- 
-@stopit.threading_timeoutable(default='TIMED_OUT') +@stopit.threading_timeoutable(default="TIMED_OUT") def grade_answer_with_timeout(res1, res2): """ Wraps the mathruler 'grade_answer' function with a timeout. @@ -53,6 +76,7 @@ def grade_answer_with_timeout(res1, res2): # The actual timeout value is passed as a keyword argument on each call. return grade_answer(res1, res2) + # --- Main Script Logic --- # 1. Load and Prepare Data @@ -67,7 +91,7 @@ def grade_answer_with_timeout(res1, res2): exit() # Filter data into questions that need processing -correct_data = [item for item in data if item.get('score') == 0] +correct_data = [item for item in data if item.get("score") == 0] if not correct_data: print(f"[{args.suffix}] No new questions to process (score=0). Exiting.") # Create an empty results file to signal completion @@ -99,12 +123,29 @@ def grade_answer_with_timeout(res1, res2): # 3. Generate Responses print(f"[{args.suffix}] Generating {args.num_samples} samples for each question...") -chats = [[{"role": "system", "content": "Please reason step by step, and put your final answer within \\boxed{}."},{"role": "user", "content": q}] for q in questions] +chats = [ + [ + { + "role": "system", + "content": "Please reason step by step, and put your final answer within \\boxed{}.", + }, + {"role": "user", "content": q}, + ] + for q in questions +] if tokenizer.chat_template: - prompts = [tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True, add_special_tokens=True) for chat in chats] + prompts = [ + tokenizer.apply_chat_template( + chat, tokenize=False, add_generation_prompt=True, add_special_tokens=True + ) + for chat in chats + ] else: - prompts = ["system: " + chat[0]["content"] + '\n' + "user: " + chat[1]["content"] for chat in chats] + prompts = [ + "system: " + chat[0]["content"] + "\n" + "user: " + chat[1]["content"] + for chat in chats + ] responses = model.generate(prompts, sampling_params=sample_params, use_tqdm=True) 
print(f"[{args.suffix}] Generation complete.") @@ -116,10 +157,12 @@ def grade_answer_with_timeout(res1, res2): try: # Extract the boxed content from all generated samples results = [extract_boxed_content(output.text) for output in response.outputs] - results = [res for res in results if res] # Filter out None/empty results + results = [res for res in results if res] # Filter out None/empty results if not results: - print(f"[{args.suffix}] WARNING: No valid boxed answers found for question: '{question[:50]}...'") + print( + f"[{args.suffix}] WARNING: No valid boxed answers found for question: '{question[:50]}...'" + ) continue answer_counts = {} @@ -127,26 +170,32 @@ def grade_answer_with_timeout(res1, res2): matched = False for existing_answer in answer_counts: # OPTIMIZATION: Perform cheap string comparisons first. - if result == existing_answer or ('no ' in result.lower() and 'no ' in existing_answer.lower()): + if result == existing_answer or ( + "no " in result.lower() and "no " in existing_answer.lower() + ): answer_counts[existing_answer] += 1 matched = True break - + # If cheap checks fail, use the expensive, timed grader. # Check both directions (A vs B and B vs A). match_1 = grade_answer_with_timeout(result, existing_answer, timeout=10) - if match_1 == 'TIMED_OUT': - print(f"[{args.suffix}] GRADER TIMEOUT on: '{result[:30]}...' vs '{existing_answer[:30]}...'") - continue # Skip to the next existing_answer - + if match_1 == "TIMED_OUT": + print( + f"[{args.suffix}] GRADER TIMEOUT on: '{result[:30]}...' vs '{existing_answer[:30]}...'" + ) + continue # Skip to the next existing_answer + if match_1: answer_counts[existing_answer] += 1 matched = True break match_2 = grade_answer_with_timeout(existing_answer, result, timeout=10) - if match_2 == 'TIMED_OUT': - print(f"[{args.suffix}] GRADER TIMEOUT on: '{existing_answer[:30]}...' vs '{result[:30]}...'") + if match_2 == "TIMED_OUT": + print( + f"[{args.suffix}] GRADER TIMEOUT on: '{existing_answer[:30]}...' 
vs '{result[:30]}...'" + ) continue if match_2: @@ -166,23 +215,33 @@ def grade_answer_with_timeout(res1, res2): score = max_count / len(results) # Skip certain question types that are hard to grade automatically - if "่ฏๆ˜Ž" in question or 'box' in question.lower() or 'text' in majority_answer.lower(): + if ( + "่ฏๆ˜Ž" in question + or "box" in question.lower() + or "text" in majority_answer.lower() + ): continue - results_all.append({ - "question": question, - "answer": majority_answer, - "score": score, - 'results': results - }) + results_all.append( + { + "question": question, + "answer": majority_answer, + "score": score, + "results": results, + } + ) except Exception as e: - print(f"[{args.suffix}] CRITICAL ERROR processing question '{question[:50]}...': {e}") + print( + f"[{args.suffix}] CRITICAL ERROR processing question '{question[:50]}...': {e}" + ) continue # 5. Save Final Results -print(f"[{args.suffix}] Processed {len(results_all)} questions. Saving results to: {OUTPUT_FILE}") +print( + f"[{args.suffix}] Processed {len(results_all)} questions. 
Saving results to: {OUTPUT_FILE}" +) with open(OUTPUT_FILE, "w") as f: json.dump(results_all, f, indent=4) -print(f"[{args.suffix}] Script finished.") \ No newline at end of file +print(f"[{args.suffix}] Script finished.") diff --git a/Agent0/curriculum_train/question_evaluate/upload.py b/Agent0/curriculum_train/question_evaluate/upload.py index 95afd83..7b02e91 100644 --- a/Agent0/curriculum_train/question_evaluate/upload.py +++ b/Agent0/curriculum_train/question_evaluate/upload.py @@ -21,9 +21,11 @@ datas = [] for i in range(8): - file_path = f'{STORAGE_PATH}/generated_question/{args.experiment_name}_{i}_results.json' + file_path = ( + f"{STORAGE_PATH}/generated_question/{args.experiment_name}_{i}_results.json" + ) try: - with open(file_path, 'r') as f: + with open(file_path, "r") as f: data = json.load(f) datas.extend(data) except FileNotFoundError: @@ -32,16 +34,18 @@ print("Cleaning up temporary JSON files...", file=sys.stderr) for i in range(8): - file_path = f'{STORAGE_PATH}/generated_question/{args.experiment_name}_{i}_results.json' + file_path = ( + f"{STORAGE_PATH}/generated_question/{args.experiment_name}_{i}_results.json" + ) try: os.remove(file_path) except FileNotFoundError: pass filtered_datas = [ - {'problem': data['question'], 'answer': data['answer'], 'score': data['score']} + {"problem": data["question"], "answer": data["answer"], "score": data["score"]} for data in datas - if args.min_score <= data.get('score', 0) <= args.max_score and data.get('answer') + if args.min_score <= data.get("score", 0) <= args.max_score and data.get("answer") ] print(f"Filtered down to {len(filtered_datas)} samples.", file=sys.stderr) @@ -53,9 +57,9 @@ os.makedirs(save_dir, exist_ok=True) save_path = f"{save_dir}/train.parquet" - + train_dataset.to_parquet(save_path) - + print(save_path) else: - print("Warning: No data to save after filtering.", file=sys.stderr) \ No newline at end of file + print("Warning: No data to save after filtering.", file=sys.stderr) diff 
--git a/Agent0/curriculum_train/question_generate/question_generate.py b/Agent0/curriculum_train/question_generate/question_generate.py index dee5433..e433573 100644 --- a/Agent0/curriculum_train/question_generate/question_generate.py +++ b/Agent0/curriculum_train/question_generate/question_generate.py @@ -8,24 +8,26 @@ import json import regex as re import os + STORAGE_PATH = os.getenv("STORAGE_PATH") + def extract_boxed(text): results, i = [], 0 - prefix = r'\boxed{' + prefix = r"\boxed{" plen = len(prefix) while True: start = text.find(prefix, i) if start == -1: - break # no more \boxed{โ€ฆ} + break # no more \boxed{โ€ฆ} j = start + plen depth = 1 while j < len(text) and depth: - if text[j] == '{': + if text[j] == "{": depth += 1 - elif text[j] == '}': + elif text[j] == "}": depth -= 1 j += 1 @@ -34,6 +36,7 @@ def extract_boxed(text): return results + def get_response_mask(response_ids, eos_token_id, dtype): batch_size, seq_len = response_ids.shape mask = torch.ones((batch_size, seq_len), dtype=dtype) @@ -44,6 +47,7 @@ def get_response_mask(response_ids, eos_token_id, dtype): break return mask + def main(args): tokenizer = AutoTokenizer.from_pretrained(args.model) if tokenizer.pad_token is None: @@ -76,26 +80,23 @@ def main(args): r"\boxed{final_answer}" "\n\n" "Do NOT output anything elseโ€”no explanations, no extra markup." - ) + ), }, { "role": "user", "content": ( "Generate one new, challenging reasoning question now. " "Remember to format the output exactly as instructed." 
- ) - } + ), + }, ] if tokenizer.chat_template: prompt = tokenizer.apply_chat_template( - chat, - tokenize=False, - add_generation_prompt=True, - add_special_tokens=True + chat, tokenize=False, add_generation_prompt=True, add_special_tokens=True ) else: - prompt = "system: " + chat[0]["content"] + '\n' + "user: " + chat[1]["content"] + prompt = "system: " + chat[0]["content"] + "\n" + "user: " + chat[1]["content"] sample_params = vllm.SamplingParams( max_tokens=4096, temperature=1.0, @@ -104,8 +105,10 @@ def main(args): stop_token_ids=[tokenizer.eos_token_id], ) - completions: List[RequestOutput] = model.generate([prompt]*args.num_samples, sampling_params=sample_params) - results=[] + completions: List[RequestOutput] = model.generate( + [prompt] * args.num_samples, sampling_params=sample_params + ) + results = [] for completion in completions: response = completion.outputs[0].text try: @@ -120,15 +123,22 @@ def main(args): results.append({"question": response, "answer": "", "score": -1}) except: results.append({"question": response, "answer": "", "score": -1}) - with open(f"{STORAGE_PATH}/generated_question/{args.save_name}_{args.suffix}.json", "w") as f: + with open( + f"{STORAGE_PATH}/generated_question/{args.save_name}_{args.suffix}.json", "w" + ) as f: json.dump(results, f, indent=4) + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--model", type=str, default="Qwen/Qwen3-4B") - parser.add_argument("--num_samples", type=int, default=1250, help="Number of samples to generate") - parser.add_argument("--suffix", type=str, default="", help="Suffix to add to the output file") + parser.add_argument( + "--num_samples", type=int, default=1250, help="Number of samples to generate" + ) + parser.add_argument( + "--suffix", type=str, default="", help="Suffix to add to the output file" + ) parser.add_argument("--save_name", type=str, default="", help="") args = parser.parse_args() - main(args) \ No newline at end of file + main(args) diff 
--git a/Agent0/curriculum_train/requirements.txt b/Agent0/curriculum_train/requirements.txt index b63d664..fcdb8fe 100644 --- a/Agent0/curriculum_train/requirements.txt +++ b/Agent0/curriculum_train/requirements.txt @@ -38,7 +38,6 @@ fastapi==0.115.12 fastapi-cli==0.0.7 fastrlock==0.8.3 filelock==3.18.0 -flash_attn==2.7.4.post1 Flask==3.1.1 fonttools==4.58.2 frozenlist==1.7.0 diff --git a/Agent0/curriculum_train/scripts/model_merger.py b/Agent0/curriculum_train/scripts/model_merger.py index 4f4dd3d..df511a6 100644 --- a/Agent0/curriculum_train/scripts/model_merger.py +++ b/Agent0/curriculum_train/scripts/model_merger.py @@ -53,12 +53,21 @@ def upload_model_to_huggingface(local_path: str, remote_path: str): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--local_dir", required=True, type=str, help="The path for your saved model") - parser.add_argument("--hf_upload_path", default=False, type=str, help="The path of the huggingface repo to upload") + parser.add_argument( + "--local_dir", required=True, type=str, help="The path for your saved model" + ) + parser.add_argument( + "--hf_upload_path", + default=False, + type=str, + help="The path of the huggingface repo to upload", + ) args = parser.parse_args() local_dir: str = args.local_dir - assert not local_dir.endswith("huggingface"), "The local_dir should not end with huggingface." + assert not local_dir.endswith( + "huggingface" + ), "The local_dir should not end with huggingface." # copy rank zero to find the shape of (dp, fsdp) rank = 0 @@ -71,7 +80,9 @@ def upload_model_to_huggingface(local_path: str, remote_path: str): assert world_size, "No model file with the proper format." 
- rank0_weight_path = os.path.join(local_dir, f"model_world_size_{world_size}_rank_{rank}.pt") + rank0_weight_path = os.path.join( + local_dir, f"model_world_size_{world_size}_rank_{rank}.pt" + ) state_dict = torch.load(rank0_weight_path, map_location="cpu", weights_only=False) pivot_key = sorted(state_dict.keys())[0] weight = state_dict[pivot_key] @@ -87,7 +98,10 @@ def upload_model_to_huggingface(local_path: str, remote_path: str): print(f"Got device mesh {mesh}, mesh_dim_names {mesh_dim_names}") - assert mesh_dim_names in (("fsdp",), ("ddp", "fsdp")), f"Unsupported mesh_dim_names {mesh_dim_names}." + assert mesh_dim_names in ( + ("fsdp",), + ("ddp", "fsdp"), + ), f"Unsupported mesh_dim_names {mesh_dim_names}." if "tp" in mesh_dim_names: # fsdp * tp @@ -104,7 +118,9 @@ def upload_model_to_huggingface(local_path: str, remote_path: str): model_state_dict_lst.extend([""] * (total_shards - 1)) def process_one_shard(rank, model_state_dict_lst): - model_path = os.path.join(local_dir, f"model_world_size_{world_size}_rank_{rank}.pt") + model_path = os.path.join( + local_dir, f"model_world_size_{world_size}_rank_{rank}.pt" + ) state_dict = torch.load(model_path, map_location="cpu", weights_only=False) model_state_dict_lst[rank] = state_dict return state_dict @@ -174,7 +190,9 @@ def process_one_shard(rank, model_state_dict_lst): raise NotImplementedError(f"Unknown architecture {architectures}.") with torch.device("meta"): - model: PreTrainedModel = AutoClass.from_config(config, torch_dtype=torch.bfloat16) + model: PreTrainedModel = AutoClass.from_config( + config, torch_dtype=torch.bfloat16 + ) assert isinstance(model, PreTrainedModel) model.to_empty(device="cpu") diff --git a/Agent0/curriculum_train/verl/__init__.py b/Agent0/curriculum_train/verl/__init__.py index cf49f90..382fa23 100644 --- a/Agent0/curriculum_train/verl/__init__.py +++ b/Agent0/curriculum_train/verl/__init__.py @@ -27,6 +27,8 @@ if os.getenv("USE_MODELSCOPE_HUB", "0").lower() in ["true", "y", "1"]: # 
Patch hub to download models from modelscope to speed up. if not is_package_available("modelscope"): - raise ImportError("You are using the modelscope hub, please install modelscope by `pip install modelscope`.") + raise ImportError( + "You are using the modelscope hub, please install modelscope by `pip install modelscope`." + ) patch_hub() diff --git a/Agent0/curriculum_train/verl/protocol.py b/Agent0/curriculum_train/verl/protocol.py index 65d48be..9c76539 100644 --- a/Agent0/curriculum_train/verl/protocol.py +++ b/Agent0/curriculum_train/verl/protocol.py @@ -45,7 +45,9 @@ __all__ = ["DataProto", "union_tensor_dict"] -def pad_dataproto_to_divisor(data: "DataProto", size_divisor: int) -> Tuple["DataProto", int]: +def pad_dataproto_to_divisor( + data: "DataProto", size_divisor: int +) -> Tuple["DataProto", int]: """Pad a DataProto to size divisible by size_divisor Args: @@ -89,7 +91,9 @@ def union_tensor_dict(tensor_dict1: TensorDict, tensor_dict2: TensorDict) -> Ten ) for key in tensor_dict2.keys(): - if key in tensor_dict1 and not torch.equal(tensor_dict1[key], tensor_dict2[key]): + if key in tensor_dict1 and not torch.equal( + tensor_dict1[key], tensor_dict2[key] + ): raise ValueError(f"Key already exists: {key}.") tensor_dict1[key] = tensor_dict2[key] @@ -97,7 +101,9 @@ def union_tensor_dict(tensor_dict1: TensorDict, tensor_dict2: TensorDict) -> Ten return tensor_dict1 -def union_numpy_dict(tensor_dict1: Dict[str, NDArray], tensor_dict2: Dict[str, NDArray]) -> Dict[str, NDArray]: +def union_numpy_dict( + tensor_dict1: Dict[str, NDArray], tensor_dict2: Dict[str, NDArray] +) -> Dict[str, NDArray]: for key in tensor_dict2.keys(): if key in tensor_dict1: assert isinstance(tensor_dict2[key], np.ndarray) @@ -137,9 +143,13 @@ def fold_batch_dim(data: "DataProto", new_batch_size: int): tensor.auto_batch_size_(batch_dims=1) for key, value in non_tensor.items(): - non_tensor[key] = np.reshape(value, newshape=(new_batch_size, -1, *value.shape[1:])) + non_tensor[key] = 
np.reshape( + value, newshape=(new_batch_size, -1, *value.shape[1:]) + ) - return DataProto(batch=tensor, non_tensor_batch=non_tensor, meta_info=data.meta_info) + return DataProto( + batch=tensor, non_tensor_batch=non_tensor, meta_info=data.meta_info + ) def collate_fn(data_items: list["DataProtoItem"]): @@ -151,7 +161,9 @@ def collate_fn(data_items: list["DataProtoItem"]): batch = torch.stack(batch).contiguous() non_tensor_batch = batch_collate(non_tensor_batch) - non_tensor_batch = {key: np.array(value, dtype=object) for key, value in non_tensor_batch.items()} + non_tensor_batch = { + key: np.array(value, dtype=object) for key, value in non_tensor_batch.items() + } return DataProto(batch=batch, non_tensor_batch=non_tensor_batch) @@ -187,11 +199,19 @@ def __len__(self) -> int: else: return 0 - def __getitem__(self, item: Union[int, slice]) -> Union["DataProto", "DataProtoItem"]: + def __getitem__( + self, item: Union[int, slice] + ) -> Union["DataProto", "DataProtoItem"]: tensor_data = self.batch[item] - non_tensor_data = {key: value[item] for key, value in self.non_tensor_batch.items()} + non_tensor_data = { + key: value[item] for key, value in self.non_tensor_batch.items() + } return_type = DataProto if isinstance(item, slice) else DataProtoItem - return return_type(batch=tensor_data, non_tensor_batch=non_tensor_data, meta_info=self.meta_info) + return return_type( + batch=tensor_data, + non_tensor_batch=non_tensor_data, + meta_info=self.meta_info, + ) def __getstate__(self) -> Tuple[bytes, Dict[str, NDArray], Dict[str, Any]]: buffer = io.BytesIO() @@ -203,7 +223,9 @@ def __getstate__(self) -> Tuple[bytes, Dict[str, NDArray], Dict[str, Any]]: buffer_bytes = buffer.getvalue() return buffer_bytes, self.non_tensor_batch, self.meta_info - def __setstate__(self, data: Tuple[bytes, Dict[str, NDArray], Dict[str, Any]]) -> None: + def __setstate__( + self, data: Tuple[bytes, Dict[str, NDArray], Dict[str, Any]] + ) -> None: batch_deserialized_bytes, non_tensor_batch, 
meta_info = data batch_deserialized = io.BytesIO(batch_deserialized_bytes) batch = torch.load(batch_deserialized, weights_only=False, map_location="cpu") @@ -247,11 +269,15 @@ def check_consistency(self): if self.batch is not None and len(self.non_tensor_batch) != 0: # TODO: we can actually lift this restriction if needed - assert len(self.batch.batch_size) == 1, "only support num_batch_dims=1 when non_tensor_batch is not empty." + assert ( + len(self.batch.batch_size) == 1 + ), "only support num_batch_dims=1 when non_tensor_batch is not empty." batch_size = self.batch.batch_size[0] for key, value in self.non_tensor_batch.items(): - assert len(value) == batch_size, f"key {key} length {len(value)} is not equal to bsz {batch_size}." + assert ( + len(value) == batch_size + ), f"key {key} length {len(value)} is not equal to bsz {batch_size}." @classmethod def from_single_dict( @@ -268,7 +294,9 @@ def from_single_dict( else: raise ValueError(f"Unsupported type in data {type(value)}") - return DataProto.from_dict(tensors=tensors, non_tensors=non_tensors, meta_info=meta_info) + return DataProto.from_dict( + tensors=tensors, non_tensors=non_tensors, meta_info=meta_info + ) @classmethod def from_dict( @@ -285,7 +313,9 @@ def from_dict( assert len(tensors) > 0, "tensors must not be empty" assert num_batch_dims > 0, "num_batch_dims must be greater than zero" if non_tensors is not None: - assert num_batch_dims == 1, "only support num_batch_dims=1 when non_tensors is not None." + assert ( + num_batch_dims == 1 + ), "only support num_batch_dims=1 when non_tensors is not None." 
meta_info = meta_info or {} non_tensors = non_tensors or {} @@ -347,7 +377,11 @@ def select( sub_batch = self.batch if non_tensor_batch_keys is not None: - non_tensor_batch = {k: v for k, v in self.non_tensor_batch.items() if k in non_tensor_batch_keys} + non_tensor_batch = { + k: v + for k, v in self.non_tensor_batch.items() + if k in non_tensor_batch_keys + } else: non_tensor_batch = self.non_tensor_batch @@ -355,14 +389,18 @@ def select( non_tensor_batch = copy.deepcopy(non_tensor_batch) if meta_info_keys is not None: - sub_meta_info = {k: v for k, v in self.meta_info.items() if k in meta_info_keys} + sub_meta_info = { + k: v for k, v in self.meta_info.items() if k in meta_info_keys + } else: sub_meta_info = self.meta_info if deepcopy: sub_meta_info = copy.deepcopy(sub_meta_info) - return DataProto(batch=sub_batch, non_tensor_batch=non_tensor_batch, meta_info=sub_meta_info) + return DataProto( + batch=sub_batch, non_tensor_batch=non_tensor_batch, meta_info=sub_meta_info + ) def pop( self, @@ -395,10 +433,14 @@ def pop( for key in meta_info_keys: meta_info[key] = self.meta_info.pop(key) - return DataProto.from_dict(tensors=tensors, non_tensors=non_tensors, meta_info=meta_info) + return DataProto.from_dict( + tensors=tensors, non_tensors=non_tensors, meta_info=meta_info + ) def rename( - self, old_keys: Optional[Union[str, List[str]]] = None, new_keys: Optional[Union[str, List[str]]] = None + self, + old_keys: Optional[Union[str, List[str]]] = None, + new_keys: Optional[Union[str, List[str]]] = None, ) -> "DataProto": """ Note that this function only rename the key in the batch @@ -411,7 +453,9 @@ def validate_input(keys): elif isinstance(keys, list): pass else: - raise TypeError(f"keys must be a list or a string, but got {type(keys)}") + raise TypeError( + f"keys must be a list or a string, but got {type(keys)}" + ) return keys old_keys = validate_input(old_keys) @@ -440,12 +484,18 @@ def union(self, other: "DataProto") -> "DataProto": DataProto: the DataProto 
after union """ self.batch = union_tensor_dict(self.batch, other.batch) - self.non_tensor_batch = union_numpy_dict(self.non_tensor_batch, other.non_tensor_batch) + self.non_tensor_batch = union_numpy_dict( + self.non_tensor_batch, other.non_tensor_batch + ) self.meta_info = union_two_dict(self.meta_info, other.meta_info) return self def make_iterator( - self, mini_batch_size: int, epochs: int, seed: int = None, dataloader_kwargs: Dict[str, Any] = None + self, + mini_batch_size: int, + epochs: int, + seed: int = None, + dataloader_kwargs: Dict[str, Any] = None, ): """Make an iterator from the DataProto. This is built upon that TensorDict can be used as a normal Pytorch dataset. See https://pytorch.org/tensordict/tutorials/data_fashion for more details. @@ -461,7 +511,9 @@ def make_iterator( Iterator: an iterator that yields a mini-batch data at a time. The total number of iteration steps is ``self.batch.batch_size * epochs // mini_batch_size`` """ - assert self.batch.batch_size[0] % mini_batch_size == 0, f"{self.batch.batch_size[0]} % {mini_batch_size} != 0" + assert ( + self.batch.batch_size[0] % mini_batch_size == 0 + ), f"{self.batch.batch_size[0]} % {mini_batch_size} != 0" # we can directly create a dataloader from TensorDict if dataloader_kwargs is None: dataloader_kwargs = {} @@ -474,7 +526,11 @@ def make_iterator( assert isinstance(dataloader_kwargs, Dict) train_dataloader = DataLoader( - dataset=self, batch_size=mini_batch_size, collate_fn=collate_fn, generator=generator, **dataloader_kwargs + dataset=self, + batch_size=mini_batch_size, + collate_fn=collate_fn, + generator=generator, + **dataloader_kwargs, ) def get_data(): @@ -494,9 +550,9 @@ def chunk(self, chunks: int) -> List["DataProto"]: Returns: List[DataProto]: a list of DataProto after splitting """ - assert len(self) % chunks == 0, ( - f"only support equal chunk. Got size of DataProto {len(self)} and chunk {chunks}." - ) + assert ( + len(self) % chunks == 0 + ), f"only support equal chunk. 
Got size of DataProto {len(self)} and chunk {chunks}." if self.batch is not None: batch_lst = self.batch.chunk(chunks=chunks, dim=0) else: @@ -513,7 +569,11 @@ def chunk(self, chunks: int) -> List["DataProto"]: output = [] for i in range(chunks): output.append( - DataProto(batch=batch_lst[i], non_tensor_batch=non_tensor_batch_lst[i], meta_info=self.meta_info) + DataProto( + batch=batch_lst[i], + non_tensor_batch=non_tensor_batch_lst[i], + meta_info=self.meta_info, + ) ) return output @@ -543,7 +603,11 @@ def concat(data: List["DataProto"]) -> "DataProto": for key, value in non_tensor_batch.items(): non_tensor_batch[key] = np.concatenate(value, axis=0) - return DataProto(batch=new_batch, non_tensor_batch=non_tensor_batch, meta_info=data[0].meta_info) + return DataProto( + batch=new_batch, + non_tensor_batch=non_tensor_batch, + meta_info=data[0].meta_info, + ) def reorder(self, indices: torch.Tensor) -> None: """ @@ -551,7 +615,9 @@ def reorder(self, indices: torch.Tensor) -> None: """ indices_np = indices.detach().numpy() self.batch = self.batch[indices] - self.non_tensor_batch = {key: value[indices_np] for key, value in self.non_tensor_batch.items()} + self.non_tensor_batch = { + key: value[indices_np] for key, value in self.non_tensor_batch.items() + } def repeat(self, repeat_times: int = 2, interleave: bool = True) -> "DataProto": """ @@ -568,12 +634,15 @@ def repeat(self, repeat_times: int = 2, interleave: bool = True) -> "DataProto": if interleave: # Interleave the data repeated_tensors = { - key: tensor.repeat_interleave(repeat_times, dim=0) for key, tensor in self.batch.items() + key: tensor.repeat_interleave(repeat_times, dim=0) + for key, tensor in self.batch.items() } else: # Stack the data repeated_tensors = { - key: tensor.unsqueeze(0).expand(repeat_times, *tensor.shape).reshape(-1, *tensor.shape[1:]) + key: tensor.unsqueeze(0) + .expand(repeat_times, *tensor.shape) + .reshape(-1, *tensor.shape[1:]) for key, tensor in self.batch.items() } @@ -589,7 
+658,9 @@ def repeat(self, repeat_times: int = 2, interleave: bool = True) -> "DataProto": if interleave: repeated_non_tensor_batch[key] = np.repeat(value, repeat_times, axis=0) else: - repeated_non_tensor_batch[key] = np.tile(value, (repeat_times,) + (1,) * (value.ndim - 1)) + repeated_non_tensor_batch[key] = np.tile( + value, (repeat_times,) + (1,) * (value.ndim - 1) + ) return DataProto( batch=repeated_batch, @@ -631,7 +702,9 @@ def dispatch_fn(x, i, chunks): return x.chunk(chunks=chunks)[i] arg_future = DataProtoFuture( - collect_fn=self.collect_fn, dispatch_fn=partial(dispatch_fn, i=i, chunks=chunks), futures=self.futures + collect_fn=self.collect_fn, + dispatch_fn=partial(dispatch_fn, i=i, chunks=chunks), + futures=self.futures, ) arg_future_lst.append(arg_future) return arg_future_lst @@ -649,7 +722,10 @@ def get(self): def allgather_dict_tensors( - tensors: Union[Dict[str, torch.Tensor], TensorDict], size: int, group: ProcessGroup, dim: int = 0 + tensors: Union[Dict[str, torch.Tensor], TensorDict], + size: int, + group: ProcessGroup, + dim: int = 0, ) -> Union[Dict[str, torch.Tensor], TensorDict]: """ TODO: optimize this. 
@@ -681,9 +757,16 @@ def all_gather_data_proto(data: DataProto, size: int, group: ProcessGroup) -> No # Note that this is an inplace operator just like torch.distributed.all_gather prev_device = data.batch.device data.batch = data.batch.cuda(device=torch.cuda.current_device()) - data.batch = allgather_dict_tensors(data.batch.contiguous(), size=size, group=group, dim=0) + data.batch = allgather_dict_tensors( + data.batch.contiguous(), size=size, group=group, dim=0 + ) data.batch = data.batch.to(prev_device) # all gather non_tensor_batch all_non_tensor_batch = [None for _ in range(size)] - torch.distributed.all_gather_object(all_non_tensor_batch, data.non_tensor_batch, group=group) - data.non_tensor_batch = {k: np.concatenate([d[k] for d in all_non_tensor_batch]) for k in data.non_tensor_batch} + torch.distributed.all_gather_object( + all_non_tensor_batch, data.non_tensor_batch, group=group + ) + data.non_tensor_batch = { + k: np.concatenate([d[k] for d in all_non_tensor_batch]) + for k in data.non_tensor_batch + } diff --git a/Agent0/curriculum_train/verl/single_controller/base/decorator.py b/Agent0/curriculum_train/verl/single_controller/base/decorator.py index b0e85a3..1091ddd 100644 --- a/Agent0/curriculum_train/verl/single_controller/base/decorator.py +++ b/Agent0/curriculum_train/verl/single_controller/base/decorator.py @@ -93,31 +93,45 @@ def dispatch_dp_compute(worker_group: "WorkerGroup", *args, **kwargs): assert isinstance(arg, (tuple, list)) and len(arg) == worker_group.world_size for value in kwargs.values(): - assert isinstance(value, (tuple, list)) and len(value) == worker_group.world_size + assert ( + isinstance(value, (tuple, list)) and len(value) == worker_group.world_size + ) return args, kwargs -def collect_dp_compute(worker_group: "WorkerGroup", outputs: List[DataProto]) -> List[DataProto]: +def collect_dp_compute( + worker_group: "WorkerGroup", outputs: List[DataProto] +) -> List[DataProto]: assert len(outputs) == worker_group.world_size return 
outputs def dispatch_dp_compute_data_proto(worker_group: "WorkerGroup", *args, **kwargs): - splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(worker_group.world_size, *args, **kwargs) + splitted_args, splitted_kwargs = _split_args_kwargs_data_proto( + worker_group.world_size, *args, **kwargs + ) return splitted_args, splitted_kwargs -def dispatch_dp_compute_data_proto_with_func(worker_group: "WorkerGroup", *args, **kwargs): +def dispatch_dp_compute_data_proto_with_func( + worker_group: "WorkerGroup", *args, **kwargs +): assert type(args[0]) is FunctionType # NOTE: The first one args is a function! - splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(worker_group.world_size, *args[1:], **kwargs) + splitted_args, splitted_kwargs = _split_args_kwargs_data_proto( + worker_group.world_size, *args[1:], **kwargs + ) splitted_args_with_func = [[args[0]] * worker_group.world_size] + splitted_args return splitted_args_with_func, splitted_kwargs -def collect_dp_compute_data_proto(worker_group: "WorkerGroup", outputs: List[DataProto]) -> DataProto: +def collect_dp_compute_data_proto( + worker_group: "WorkerGroup", outputs: List[DataProto] +) -> DataProto: for output in outputs: - assert isinstance(output, (DataProto, ray.ObjectRef)), f"Expect a DataProto, but got {type(output)}" + assert isinstance( + output, (DataProto, ray.ObjectRef) + ), f"Expect a DataProto, but got {type(output)}" outputs = collect_dp_compute(worker_group, outputs) return _concat_data_proto_or_future(outputs) @@ -165,18 +179,26 @@ def get_predefined_execute_fn(execute_mode: Execute): return predefined_execute_mode_fn[execute_mode] -def _check_dispatch_mode(dispatch_mode: Union[Dispatch, Dict[Literal["dispatch_fn", "collect_fn"], FunctionType]]): - assert isinstance(dispatch_mode, (Dispatch, dict)), ( - f"dispatch_mode must be a Dispatch or a Dict. 
Got {dispatch_mode}" - ) +def _check_dispatch_mode( + dispatch_mode: Union[ + Dispatch, Dict[Literal["dispatch_fn", "collect_fn"], FunctionType] + ], +): + assert isinstance( + dispatch_mode, (Dispatch, dict) + ), f"dispatch_mode must be a Dispatch or a Dict. Got {dispatch_mode}" if isinstance(dispatch_mode, dict): necessary_keys = ["dispatch_fn", "collect_fn"] for key in necessary_keys: - assert key in dispatch_mode, f"key {key} should be in dispatch_mode if it is a dictionary" + assert ( + key in dispatch_mode + ), f"key {key} should be in dispatch_mode if it is a dictionary" def _check_execute_mode(execute_mode: Execute): - assert isinstance(execute_mode, Execute), f"execute_mode must be a Execute. Got {execute_mode}" + assert isinstance( + execute_mode, Execute + ), f"execute_mode must be a Execute. Got {execute_mode}" def _materialize_futures(*args, **kwargs): @@ -195,7 +217,12 @@ def _materialize_futures(*args, **kwargs): return new_args, kwargs -def register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.ALL, blocking=True, materialize_futures=True): +def register( + dispatch_mode=Dispatch.ALL_TO_ALL, + execute_mode=Execute.ALL, + blocking=True, + materialize_futures=True, +): _check_dispatch_mode(dispatch_mode=dispatch_mode) _check_execute_mode(execute_mode=execute_mode) @@ -206,7 +233,11 @@ def inner(*args, **kwargs): args, kwargs = _materialize_futures(*args, **kwargs) return func(*args, **kwargs) - attrs = {"dispatch_mode": dispatch_mode, "execute_mode": execute_mode, "blocking": blocking} + attrs = { + "dispatch_mode": dispatch_mode, + "execute_mode": execute_mode, + "blocking": blocking, + } setattr(inner, MAGIC_ATTR, attrs) return inner diff --git a/Agent0/curriculum_train/verl/single_controller/base/worker.py b/Agent0/curriculum_train/verl/single_controller/base/worker.py index 9ecffca..8f456e3 100644 --- a/Agent0/curriculum_train/verl/single_controller/base/worker.py +++ b/Agent0/curriculum_train/verl/single_controller/base/worker.py @@ 
-78,7 +78,10 @@ def __init__(self, store) -> None: self._store = store def to_dict(self): - return {f"_{key.lower()}": self._store.get(f"_{key.lower()}", None) for key in WorkerMeta.keys} + return { + f"_{key.lower()}": self._store.get(f"_{key.lower()}", None) + for key in WorkerMeta.keys + } # we assume that in each WorkerGroup, there is a Master Worker @@ -105,8 +108,13 @@ def __new__(cls, *args, **kwargs): worker_group_prefix = os.getenv("WG_PREFIX", None) # when decorator @ray.remote applies, __new__ will be called while we don't want to apply _configure_before_init - if None not in [rank, worker_group_prefix] and "ActorClass(" not in cls.__name__: - instance._configure_before_init(f"{worker_group_prefix}_register_center", int(rank)) + if ( + None not in [rank, worker_group_prefix] + and "ActorClass(" not in cls.__name__ + ): + instance._configure_before_init( + f"{worker_group_prefix}_register_center", int(rank) + ) return instance @@ -119,7 +127,9 @@ def _configure_before_init(self, register_center_name: str, rank: int): "MASTER_ADDR": master_addr, "MASTER_PORT": master_port, } - self.register_center = create_worker_group_register_center(name=register_center_name, info=rank_zero_info) + self.register_center = create_worker_group_register_center( + name=register_center_name, info=rank_zero_info + ) os.environ.update(rank_zero_info) def __init__(self, cuda_visible_devices=None) -> None: @@ -169,7 +179,9 @@ def _configure_with_meta(self, meta: WorkerMeta): os.environ[key] = str(val) os.environ["REDIS_STORE_SERVER_HOST"] = ( - str(self._master_addr).replace("[", "").replace("]", "") if self._master_addr else "" + str(self._master_addr).replace("[", "").replace("]", "") + if self._master_addr + else "" ) def get_master_addr_port(self): diff --git a/Agent0/curriculum_train/verl/single_controller/base/worker_group.py b/Agent0/curriculum_train/verl/single_controller/base/worker_group.py index 8648fbf..4e61b64 100644 --- 
a/Agent0/curriculum_train/verl/single_controller/base/worker_group.py +++ b/Agent0/curriculum_train/verl/single_controller/base/worker_group.py @@ -21,14 +21,22 @@ import time from typing import Any, Callable, Dict, List, Optional -from .decorator import MAGIC_ATTR, Dispatch, get_predefined_dispatch_fn, get_predefined_execute_fn +from .decorator import ( + MAGIC_ATTR, + Dispatch, + get_predefined_dispatch_fn, + get_predefined_execute_fn, +) class ResourcePool: """The resource pool with meta info such as world size.""" def __init__( - self, process_on_nodes: Optional[Any] = None, max_colocate_count: int = 10, n_gpus_per_node: int = 8 + self, + process_on_nodes: Optional[Any] = None, + max_colocate_count: int = 10, + n_gpus_per_node: int = 8, ) -> None: if process_on_nodes is None: process_on_nodes = [] @@ -53,12 +61,15 @@ def store(self): def local_world_size_list(self) -> List[int]: nested_local_world_size_list = [ - [local_world_size for _ in range(local_world_size)] for local_world_size in self._store + [local_world_size for _ in range(local_world_size)] + for local_world_size in self._store ] return [item for row in nested_local_world_size_list for item in row] def local_rank_list(self) -> List[int]: - nested_local_rank_list = [[i for i in range(local_world_size)] for local_world_size in self._store] # noqa: C416 + nested_local_rank_list = [ + [i for i in range(local_world_size)] for local_world_size in self._store + ] # noqa: C416 return [item for row in nested_local_rank_list for item in row] @@ -81,7 +92,9 @@ def check_workers_alive(workers: List, is_alive: Callable, gap_time: float = 1) while True: for worker in workers: if not is_alive(worker): - logging.warning(f"Worker {worker} is not alive, sending signal to main thread") + logging.warning( + f"Worker {worker} is not alive, sending signal to main thread" + ) signal.raise_signal(signal.SIGABRT) time.sleep(gap_time) @@ -108,7 +121,9 @@ def __init__(self, resource_pool: ResourcePool, **kwargs) -> None: 
self._checker_thread: threading.Thread = None def _is_worker_alive(self, worker): - raise NotImplementedError("WorkerGroup._is_worker_alive called, should be implemented in derived class.") + raise NotImplementedError( + "WorkerGroup._is_worker_alive called, should be implemented in derived class." + ) def _block_until_all_workers_alive(self) -> None: while True: @@ -123,7 +138,8 @@ def start_worker_aliveness_check(self, every_n_seconds=1) -> None: self._block_until_all_workers_alive() self._checker_thread = threading.Thread( - target=check_workers_alive, args=(self._workers, self._is_worker_alive, every_n_seconds) + target=check_workers_alive, + args=(self._workers, self._is_worker_alive, every_n_seconds), ) self._checker_thread.start() @@ -138,7 +154,9 @@ def _bind_worker_method(self, user_defined_cls, func_generator): for method_name in dir(user_defined_cls): try: method = getattr(user_defined_cls, method_name) - assert callable(method), f"{method_name} in {user_defined_cls} is not callable" + assert callable( + method + ), f"{method_name} in {user_defined_cls} is not callable" except Exception: # if it is a property, it will fail because Class doesn't have instance property continue @@ -146,8 +164,12 @@ def _bind_worker_method(self, user_defined_cls, func_generator): if hasattr(method, MAGIC_ATTR): # this method is decorated by register attribute = getattr(method, MAGIC_ATTR) - assert isinstance(attribute, Dict), f"attribute must be a dictionary. Got {type(attribute)}" - assert "dispatch_mode" in attribute, "attribute must contain dispatch_mode in its key" + assert isinstance( + attribute, Dict + ), f"attribute must be a dictionary. 
Got {type(attribute)}" + assert ( + "dispatch_mode" in attribute + ), "attribute must contain dispatch_mode in its key" dispatch_mode = attribute["dispatch_mode"] execute_mode = attribute["execute_mode"] diff --git a/Agent0/curriculum_train/verl/single_controller/ray/__init__.py b/Agent0/curriculum_train/verl/single_controller/ray/__init__.py index 25b3141..3f099f1 100644 --- a/Agent0/curriculum_train/verl/single_controller/ray/__init__.py +++ b/Agent0/curriculum_train/verl/single_controller/ray/__init__.py @@ -12,7 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup, create_colocated_worker_cls +from .base import ( + RayClassWithInitArgs, + RayResourcePool, + RayWorkerGroup, + create_colocated_worker_cls, +) -__all__ = ["RayClassWithInitArgs", "RayResourcePool", "RayWorkerGroup", "create_colocated_worker_cls"] +__all__ = [ + "RayClassWithInitArgs", + "RayResourcePool", + "RayWorkerGroup", + "create_colocated_worker_cls", +] diff --git a/Agent0/curriculum_train/verl/single_controller/ray/base.py b/Agent0/curriculum_train/verl/single_controller/ray/base.py index 9827312..aa0355f 100644 --- a/Agent0/curriculum_train/verl/single_controller/ray/base.py +++ b/Agent0/curriculum_train/verl/single_controller/ray/base.py @@ -25,7 +25,10 @@ from ray.experimental.state.api import get_actor from ray.util import list_named_actors from ray.util.placement_group import PlacementGroup, placement_group -from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy, PlacementGroupSchedulingStrategy +from ray.util.scheduling_strategies import ( + NodeAffinitySchedulingStrategy, + PlacementGroupSchedulingStrategy, +) from ..base import ClassWithInitArgs, ResourcePool, Worker, WorkerGroup from ..base.decorator import MAGIC_ATTR @@ -88,17 +91,25 @@ def __init__( self.pgs = None self.detached = detached - def get_placement_groups(self, strategy: 
str = "STRICT_PACK", name: Optional[str] = None) -> List[PlacementGroup]: + def get_placement_groups( + self, strategy: str = "STRICT_PACK", name: Optional[str] = None + ) -> List[PlacementGroup]: if self.pgs is not None: return self.pgs pg_name_prefix = ( - name if name else f"{self.name_prefix}verl_group_{'_'.join([str(count) for count in self._store])}:" + name + if name + else f"{self.name_prefix}verl_group_{'_'.join([str(count) for count in self._store])}:" ) # print(f"pg_name_prefix = {pg_name_prefix}") pg_scheme = [ [ - {"CPU": self.max_colocate_count, "GPU": 1} if self.use_gpu else {"CPU": self.max_colocate_count} + ( + {"CPU": self.max_colocate_count, "GPU": 1} + if self.use_gpu + else {"CPU": self.max_colocate_count} + ) for _ in range(process_count) ] for process_count in self._store @@ -107,7 +118,12 @@ def get_placement_groups(self, strategy: str = "STRICT_PACK", name: Optional[str lifetime = "detached" if self.detached else None pgs = [ - placement_group(bundles=bundles, strategy=strategy, name=pg_name_prefix + str(idx), lifetime=lifetime) + placement_group( + bundles=bundles, + strategy=strategy, + name=pg_name_prefix + str(idx), + lifetime=lifetime, + ) for idx, bundles in enumerate(pg_scheme) ] @@ -118,7 +134,9 @@ def get_placement_groups(self, strategy: str = "STRICT_PACK", name: Optional[str def extract_pg_from_exist( - resource_pools: Dict[str, RayResourcePool], src_role_names: List[str], resource_pool: RayResourcePool + resource_pools: Dict[str, RayResourcePool], + src_role_names: List[str], + resource_pool: RayResourcePool, ) -> List[PlacementGroup]: src_pgs = [ pg @@ -128,15 +146,19 @@ def extract_pg_from_exist( ] sorted_src_pgs = sorted(src_pgs, key=lambda pg: pg.bundle_count, reverse=True) - sorted_process_on_nodes = sorted([(val, idx) for idx, val in enumerate(resource_pool.store)], reverse=True) + sorted_process_on_nodes = sorted( + [(val, idx) for idx, val in enumerate(resource_pool.store)], reverse=True + ) unsorted_pgs: List[Tuple[int, 
PlacementGroup]] = [] searching_idx = 0 for request_process, original_idx in sorted_process_on_nodes: - assert searching_idx < len(sorted_src_pgs), f"no enough nodes for request: searching {searching_idx} th node" - assert request_process <= sorted_src_pgs[searching_idx].bundle_count, ( - f"requesting {request_process} processes, bundle count cannot satisfy" - ) + assert searching_idx < len( + sorted_src_pgs + ), f"no enough nodes for request: searching {searching_idx} th node" + assert ( + request_process <= sorted_src_pgs[searching_idx].bundle_count + ), f"requesting {request_process} processes, bundle count cannot satisfy" unsorted_pgs.append((original_idx, sorted_src_pgs[searching_idx])) searching_idx += 1 @@ -145,15 +167,21 @@ def extract_pg_from_exist( def merge_resource_pool(rp1: RayResourcePool, rp2: RayResourcePool) -> RayResourcePool: assert rp1.use_gpu == rp2.use_gpu, "Both RayResourcePool must either use_gpu or not" - assert rp1.max_colocate_count == rp2.max_colocate_count, ( - "Both RayResourcePool must has the same max_colocate_count" - ) - assert rp1.n_gpus_per_node == rp2.n_gpus_per_node, "Both RayResourcePool must has the same n_gpus_per_node" - assert rp1.detached == rp2.detached, "Detached ResourcePool cannot be merged with non-detached ResourcePool" + assert ( + rp1.max_colocate_count == rp2.max_colocate_count + ), "Both RayResourcePool must has the same max_colocate_count" + assert ( + rp1.n_gpus_per_node == rp2.n_gpus_per_node + ), "Both RayResourcePool must has the same n_gpus_per_node" + assert ( + rp1.detached == rp2.detached + ), "Detached ResourcePool cannot be merged with non-detached ResourcePool" new_store = rp1.store + rp2.store - merged = RayResourcePool(new_store, rp1.use_gpu, f"{rp1.name_prefix}_{rp2.name_prefix}") + merged = RayResourcePool( + new_store, rp1.use_gpu, f"{rp1.name_prefix}_{rp2.name_prefix}" + ) merged.pgs = rp1.get_placement_groups() + rp2.get_placement_groups() return merged @@ -182,15 +210,22 @@ def __call__( ) -> 
Any: if sharing_with is not None: target_node_id = ray.get(sharing_with.get_node_id.remote()) - cuda_visible_devices = ray.get(sharing_with.get_cuda_visible_devices.remote()) - options = {"scheduling_strategy": NodeAffinitySchedulingStrategy(node_id=target_node_id, soft=False)} + cuda_visible_devices = ray.get( + sharing_with.get_cuda_visible_devices.remote() + ) + options = { + "scheduling_strategy": NodeAffinitySchedulingStrategy( + node_id=target_node_id, soft=False + ) + } return self.cls.options(**options).remote( *self.args, cuda_visible_devices=cuda_visible_devices, **self.kwargs ) options = { "scheduling_strategy": PlacementGroupSchedulingStrategy( - placement_group=placement_group, placement_group_bundle_index=placement_group_bundle_idx + placement_group=placement_group, + placement_group_bundle_index=placement_group_bundle_idx, ) } options.update(self._options) @@ -221,7 +256,9 @@ def __init__( ) -> None: super().__init__(resource_pool=resource_pool, **kwargs) self.ray_cls_with_init = ray_cls_with_init - self.name_prefix = get_random_string(length=6) if name_prefix is None else name_prefix + self.name_prefix = ( + get_random_string(length=6) if name_prefix is None else name_prefix + ) if worker_names is not None: assert self._is_init_with_detached_workers @@ -231,7 +268,10 @@ def __init__( self._init_with_detached_workers(worker_names=worker_names) else: self._init_with_resource_pool( - resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init, bin_pack=bin_pack, detached=detached + resource_pool=resource_pool, + ray_cls_with_init=ray_cls_with_init, + bin_pack=bin_pack, + detached=detached, ) if ray_cls_with_init is not None: @@ -239,7 +279,11 @@ def __init__( def _is_worker_alive(self, worker: ActorHandle) -> bool: worker_state_dict = get_actor(worker._actor_id.hex()) - return worker_state_dict.get("state", "undefined") == "ALIVE" if worker_state_dict is not None else False + return ( + worker_state_dict.get("state", "undefined") == "ALIVE" + if 
worker_state_dict is not None + else False + ) def _init_with_detached_workers(self, worker_names: List[str]) -> None: workers = [ray.get_actor(name=name) for name in worker_names] @@ -247,7 +291,11 @@ def _init_with_detached_workers(self, worker_names: List[str]) -> None: self._world_size = len(worker_names) def _init_with_resource_pool( - self, resource_pool: RayResourcePool, ray_cls_with_init: RayClassWithInitArgs, bin_pack: bool, detached: bool + self, + resource_pool: RayResourcePool, + ray_cls_with_init: RayClassWithInitArgs, + bin_pack: bool, + detached: bool, ): use_gpu = resource_pool.use_gpu @@ -264,7 +312,9 @@ def _init_with_resource_pool( rank = -1 local_world_size = resource_pool.store[0] for pg_idx, pg in enumerate(sort_placement_group_by_node_ip(pgs)): - assert local_world_size <= pg.bundle_count, f"when generating for {self.name_prefix}, for the " + assert ( + local_world_size <= pg.bundle_count + ), f"when generating for {self.name_prefix}, for the " for local_rank in range(local_world_size): rank += 1 @@ -282,18 +332,27 @@ def _init_with_resource_pool( env_vars["MASTER_PORT"] = self._master_port cia_name = type(ray_cls_with_init.cls).__name__ - match = re.search(r"ActorClass\(([^)]+)\)", cia_name) # ray.remote(Obj) -> "ActorClass(Obj)" - cia_name = match.group(1) if match else cia_name # "ActorClass(Obj)" -> "Obj" + match = re.search( + r"ActorClass\(([^)]+)\)", cia_name + ) # ray.remote(Obj) -> "ActorClass(Obj)" + cia_name = ( + match.group(1) if match else cia_name + ) # "ActorClass(Obj)" -> "Obj" name = f"{self.name_prefix}{cia_name}_{pg_idx}:{local_rank}" # e.g. 
Worker_2:5 - ray_cls_with_init.update_options({"runtime_env": {"env_vars": env_vars}, "name": name}) + ray_cls_with_init.update_options( + {"runtime_env": {"env_vars": env_vars}, "name": name} + ) if detached: ray_cls_with_init.update_options({"lifetime": "detached"}) # create a worker worker = ray_cls_with_init( - placement_group=pg, placement_group_bundle_idx=local_rank, use_gpu=use_gpu, num_gpus=num_gpus + placement_group=pg, + placement_group_bundle_idx=local_rank, + use_gpu=use_gpu, + num_gpus=num_gpus, ) self._workers.append(worker) self._worker_names.append(name) @@ -301,16 +360,26 @@ def _init_with_resource_pool( if rank == 0: register_center_actor = None for _ in range(120): - if f"{self.name_prefix}_register_center" not in list_named_actors(): + if ( + f"{self.name_prefix}_register_center" + not in list_named_actors() + ): time.sleep(1) else: - register_center_actor = ray.get_actor(f"{self.name_prefix}_register_center") + register_center_actor = ray.get_actor( + f"{self.name_prefix}_register_center" + ) break - assert register_center_actor is not None, ( - f"failed to get register_center_actor: {self.name_prefix}_register_center in {list_named_actors(all_namespaces=True)}" + assert ( + register_center_actor is not None + ), f"failed to get register_center_actor: {self.name_prefix}_register_center in {list_named_actors(all_namespaces=True)}" + rank_zero_info = ray.get( + register_center_actor.get_rank_zero_info.remote() + ) + self._master_addr, self._master_port = ( + rank_zero_info["MASTER_ADDR"], + rank_zero_info["MASTER_PORT"], ) - rank_zero_info = ray.get(register_center_actor.get_rank_zero_info.remote()) - self._master_addr, self._master_port = rank_zero_info["MASTER_ADDR"], rank_zero_info["MASTER_PORT"] # print(f"rank_zero_info: {rank_zero_info}") # print(f"master_addr: {self._master_addr}, master_port: {self._master_port}") @@ -321,7 +390,10 @@ def worker_names(self): @classmethod def from_detached(cls, worker_names=None, ray_cls_with_init=None): 
worker_group = cls( - resource_pool=None, ray_cls_with_init=ray_cls_with_init, name_prefix=None, worker_names=worker_names + resource_pool=None, + ray_cls_with_init=ray_cls_with_init, + name_prefix=None, + worker_names=worker_names, ) return worker_group @@ -346,7 +418,8 @@ def _rebind_actor_methods(worker_group, actor_name): new_worker_group_dict = {} for prefix in prefix_set: new_worker_group = self.from_detached( - worker_names=self._worker_names, ray_cls_with_init=self.ray_cls_with_init + worker_names=self._worker_names, + ray_cls_with_init=self.ray_cls_with_init, ) _rebind_actor_methods(new_worker_group, prefix) @@ -375,8 +448,12 @@ def execute_all_async(self, method_name: str, *args, **kwargs): # then we will send each element in the list to the corresponding worker. # print(f"execute_all_async: method {method_name}({args}, {kwargs})") length = len(self._workers) - if all(isinstance(arg, list) for arg in args) and all(isinstance(kwarg, list) for kwarg in kwargs.values()): - if all(len(arg) == length for arg in args) and all(len(kwarg) == length for kwarg in kwargs.values()): + if all(isinstance(arg, list) for arg in args) and all( + isinstance(kwarg, list) for kwarg in kwargs.values() + ): + if all(len(arg) == length for arg in args) and all( + len(kwarg) == length for kwarg in kwargs.values() + ): # print(f"splitting args and kwargs into {length} shards") result = [] for i in range(length): @@ -386,7 +463,10 @@ def execute_all_async(self, method_name: str, *args, **kwargs): result.append(remote_call.remote(*sliced_args, **sliced_kwargs)) return result - return [getattr(worker, method_name).remote(*args, **kwargs) for worker in self._workers] + return [ + getattr(worker, method_name).remote(*args, **kwargs) + for worker in self._workers + ] @property def master_address(self): @@ -419,7 +499,9 @@ def _bind_workers_method_to_parent(cls, key, user_defined_cls): for method_name in dir(user_defined_cls): try: method = getattr(user_defined_cls, method_name) - 
assert callable(method), f"{method_name} in {user_defined_cls} is not callable" + assert callable( + method + ), f"{method_name} in {user_defined_cls} is not callable" except Exception: # if it is a property, it will fail because Class doesn't have instance property continue @@ -462,9 +544,9 @@ def create_colocated_worker_cls(class_dict: dict[str, RayClassWithInitArgs]): if worker_cls is None: worker_cls = cls.cls.__ray_actor_class__.__base__ else: - assert worker_cls == cls.cls.__ray_actor_class__.__base__, ( - "the worker class should be the same when share the same process" - ) + assert ( + worker_cls == cls.cls.__ray_actor_class__.__base__ + ), "the worker class should be the same when share the same process" cls_dict[key] = cls.cls init_args_dict[key] = {"args": cls.args, "kwargs": cls.kwargs} @@ -480,7 +562,8 @@ def __init__(self): # directly instantiate the class without remote with patch.dict(os.environ, {"DISABLE_WORKER_INIT": "1"}): self.worker_dict[key] = user_defined_cls( - *init_args_dict[key].get("args", ()), **init_args_dict[key].get("kwargs", {}) + *init_args_dict[key].get("args", ()), + **init_args_dict[key].get("kwargs", {}), ) # now monkey-patch the methods from inner class to WorkerDict diff --git a/Agent0/curriculum_train/verl/trainer/config.py b/Agent0/curriculum_train/verl/trainer/config.py index ef2852d..3a18369 100644 --- a/Agent0/curriculum_train/verl/trainer/config.py +++ b/Agent0/curriculum_train/verl/trainer/config.py @@ -72,6 +72,7 @@ class AlgorithmConfig: kl_target: float = 0.0 mock_data: str = "" + @dataclass class TrainerConfig: total_epochs: int = 10 @@ -93,9 +94,13 @@ class TrainerConfig: def post_init(self): if self.save_checkpoint_path is None: - self.save_checkpoint_path = os.path.join("checkpoints", self.project_name, self.experiment_name) + self.save_checkpoint_path = os.path.join( + "checkpoints", self.project_name, self.experiment_name + ) - self.save_checkpoint_path = os.path.abspath(self.save_checkpoint_path) # ray job 
uses absolute path + self.save_checkpoint_path = os.path.abspath( + self.save_checkpoint_path + ) # ray job uses absolute path if self.load_checkpoint_path is not None: self.load_checkpoint_path = os.path.abspath(self.load_checkpoint_path) @@ -110,7 +115,9 @@ class PPOConfig: def post_init(self): self.worker.rollout.prompt_length = self.data.max_prompt_length self.worker.rollout.response_length = self.data.max_response_length - self.worker.rollout.trust_remote_code = self.worker.actor.model.trust_remote_code + self.worker.rollout.trust_remote_code = ( + self.worker.actor.model.trust_remote_code + ) self.worker.actor.disable_kl = self.algorithm.disable_kl self.worker.actor.use_kl_loss = self.algorithm.use_kl_loss self.worker.actor.kl_penalty = self.algorithm.kl_penalty diff --git a/Agent0/curriculum_train/verl/trainer/core_algos.py b/Agent0/curriculum_train/verl/trainer/core_algos.py index 86f9410..17846d0 100644 --- a/Agent0/curriculum_train/verl/trainer/core_algos.py +++ b/Agent0/curriculum_train/verl/trainer/core_algos.py @@ -46,7 +46,8 @@ def update(self, current_kl: float, n_steps: int) -> None: class AdaptiveKLController(KLController): """Adaptive KL controller described in: https://arxiv.org/pdf/1909.08593.pdf - Copied from https://github.com/huggingface/trl/blob/v0.11.0/trl/trainer/utils.py#L54""" + Copied from https://github.com/huggingface/trl/blob/v0.11.0/trl/trainer/utils.py#L54 + """ def __init__(self, init_kl_coef: float, target_kl: float, horizon: float): self.kl_coef = init_kl_coef @@ -63,7 +64,8 @@ def update(self, current_kl: float, n_steps: int) -> None: class FixedKLController(KLController): """Fixed KL controller. 
- Copeid from https://github.com/huggingface/trl/blob/v0.11.0/trl/trainer/utils.py#L72""" + Copeid from https://github.com/huggingface/trl/blob/v0.11.0/trl/trainer/utils.py#L72 + """ def __init__(self, init_kl_coef: float): self.kl_coef = init_kl_coef @@ -77,7 +79,9 @@ def get_kl_controller(algorithm_config: "AlgorithmConfig") -> KLController: if algorithm_config.kl_type == "fixed": kl_ctrl = FixedKLController(init_kl_coef=algorithm_config.kl_coef) elif algorithm_config.kl_type == "adaptive": - assert algorithm_config.kl_horizon > 0, f"horizon must be larger than 0. Got {algorithm_config.kl_horizon}." + assert ( + algorithm_config.kl_horizon > 0 + ), f"horizon must be larger than 0. Got {algorithm_config.kl_horizon}." kl_ctrl = AdaptiveKLController( init_kl_coef=algorithm_config.kl_coef, target_kl=algorithm_config.kl_target, @@ -136,7 +140,10 @@ def compute_gae_advantage_return( # NOTE(sgm): this implementation only consider outcome supervision, where the reward is a scalar. @torch.no_grad() def compute_grpo_outcome_advantage( - token_level_rewards: torch.Tensor, response_mask: torch.Tensor, index: torch.Tensor, eps: float = 1e-6 + token_level_rewards: torch.Tensor, + response_mask: torch.Tensor, + index: torch.Tensor, + eps: float = 1e-6, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Compute advantage for GRPO, operating only on Outcome reward @@ -251,7 +258,9 @@ def compute_reinforce_plus_plus_outcome_advantage( @torch.no_grad() def compute_remax_outcome_advantage( - token_level_rewards: torch.Tensor, reward_baselines: torch.Tensor, response_mask: torch.Tensor + token_level_rewards: torch.Tensor, + reward_baselines: torch.Tensor, + response_mask: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Compute advantage for ReMax, operating only on Outcome reward @@ -333,7 +342,11 @@ def compute_policy_loss( # see: https://github.com/pytorch/pytorch/issues/10729 ratio = torch.exp(negative_approx_kl) clipped_ratio = torch.exp( - torch.clamp(negative_approx_kl, 
np.log(1.0 - clip_ratio_low), np.log(1.0 + clip_ratio_high)) + torch.clamp( + negative_approx_kl, + np.log(1.0 - clip_ratio_low), + np.log(1.0 + clip_ratio_high), + ) ) pg_loss = -advantages * ratio @@ -342,9 +355,15 @@ def compute_policy_loss( clipped_pg_loss_higher = torch.max(pg_loss, pg_loss2) # clip if pg_loss < pg_loss2 pg_clipfrac_higher = (pg_loss < pg_loss2).float() - clipped_pg_loss_lower = torch.min(clipped_pg_loss_higher, pg_loss3) # clip if pg_loss > pg_loss3 and adv < 0 - final_pg_loss = torch.where(advantages < 0, clipped_pg_loss_lower, clipped_pg_loss_higher) - pg_clipfrac_lower = (clipped_pg_loss_higher > pg_loss3).float() * (advantages < 0).float() + clipped_pg_loss_lower = torch.min( + clipped_pg_loss_higher, pg_loss3 + ) # clip if pg_loss > pg_loss3 and adv < 0 + final_pg_loss = torch.where( + advantages < 0, clipped_pg_loss_lower, clipped_pg_loss_higher + ) + pg_clipfrac_lower = (clipped_pg_loss_higher > pg_loss3).float() * ( + advantages < 0 + ).float() final_pg_loss = VF.masked_mean(final_pg_loss, response_mask) pg_clipfrac_higher = VF.masked_mean(pg_clipfrac_higher, response_mask) @@ -383,15 +402,21 @@ def compute_value_loss( The ratio of vf being clipped """ - vpredclipped = torch.clamp(vpreds, values - cliprange_value, values + cliprange_value) + vpredclipped = torch.clamp( + vpreds, values - cliprange_value, values + cliprange_value + ) vf_loss1 = torch.square(vpreds - returns) vf_loss2 = torch.square(vpredclipped - returns) - vf_loss = 0.5 * VF.masked_mean(torch.max(vf_loss1, vf_loss2), action_mask) # clip if vf_loss1 < vf_loss2 + vf_loss = 0.5 * VF.masked_mean( + torch.max(vf_loss1, vf_loss2), action_mask + ) # clip if vf_loss1 < vf_loss2 vf_clipfrac = VF.masked_mean((vf_loss1 < vf_loss2).float(), action_mask) return vf_loss, vf_clipfrac -def compute_kl(log_probs: torch.FloatTensor, ref_log_probs: torch.FloatTensor, kl_penalty: str) -> torch.Tensor: +def compute_kl( + log_probs: torch.FloatTensor, ref_log_probs: torch.FloatTensor, 
kl_penalty: str +) -> torch.Tensor: """Compute KL divergence given log_probs and ref_log_probs. Adapted from https://github.com/huggingface/trl/blob/v0.11.0/trl/trainer/ppo_trainer.py#L1150 @@ -423,6 +448,8 @@ def compute_kl(log_probs: torch.FloatTensor, ref_log_probs: torch.FloatTensor, k return torch.clamp(kld, min=-10, max=10) if kl_penalty == "full": - return F.kl_div(ref_log_probs, log_probs, log_target=True, reduction="none").sum(-1) + return F.kl_div( + ref_log_probs, log_probs, log_target=True, reduction="none" + ).sum(-1) raise NotImplementedError(f"Unknown KL penalty: {kl_penalty}.") diff --git a/Agent0/curriculum_train/verl/trainer/data_loader.py b/Agent0/curriculum_train/verl/trainer/data_loader.py index cb6881b..40d9d5e 100644 --- a/Agent0/curriculum_train/verl/trainer/data_loader.py +++ b/Agent0/curriculum_train/verl/trainer/data_loader.py @@ -23,7 +23,11 @@ from .config import DataConfig -def create_dataloader(config: DataConfig, tokenizer: PreTrainedTokenizer, processor: Optional[ProcessorMixin]) -> None: +def create_dataloader( + config: DataConfig, + tokenizer: PreTrainedTokenizer, + processor: Optional[ProcessorMixin], +) -> None: train_dataset = RLHFDataset( data_path=config.train_files, tokenizer=tokenizer, @@ -42,7 +46,9 @@ def create_dataloader(config: DataConfig, tokenizer: PreTrainedTokenizer, proces if config.shuffle: train_dataloader_generator = torch.Generator() train_dataloader_generator.manual_seed(config.seed) - sampler = RandomSampler(data_source=train_dataset, generator=train_dataloader_generator) + sampler = RandomSampler( + data_source=train_dataset, generator=train_dataloader_generator + ) else: sampler = SequentialSampler(data_source=train_dataset) @@ -72,7 +78,9 @@ def create_dataloader(config: DataConfig, tokenizer: PreTrainedTokenizer, proces ) val_dataloader = StatefulDataLoader( dataset=val_dataset, - batch_size=len(val_dataset) if config.val_batch_size == -1 else config.val_batch_size, + batch_size=( + len(val_dataset) if 
config.val_batch_size == -1 else config.val_batch_size + ), shuffle=False, num_workers=8, collate_fn=collate_fn, diff --git a/Agent0/curriculum_train/verl/trainer/main.py b/Agent0/curriculum_train/verl/trainer/main.py index 2c552bd..c1e8986 100644 --- a/Agent0/curriculum_train/verl/trainer/main.py +++ b/Agent0/curriculum_train/verl/trainer/main.py @@ -65,20 +65,28 @@ def run(self, config: PPOConfig): Role.Critic: global_pool_id, Role.RefPolicy: global_pool_id, } - resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) + resource_pool_manager = ResourcePoolManager( + resource_pool_spec=resource_pool_spec, mapping=mapping + ) if config.worker.reward.reward_type == "sequential": RewardManager = SequentialFunctionRewardManager elif config.worker.reward.reward_type == "batch": RewardManager = BatchFunctionRewardManager else: - raise NotImplementedError(f"Unknown reward type {config.worker.reward.reward_type}.") + raise NotImplementedError( + f"Unknown reward type {config.worker.reward.reward_type}." 
+ ) - RemoteRewardManager = ray.remote(RewardManager).options(num_cpus=config.worker.reward.num_cpus) + RemoteRewardManager = ray.remote(RewardManager).options( + num_cpus=config.worker.reward.num_cpus + ) reward_fn = RemoteRewardManager.remote(config.worker.reward, tokenizer) val_reward_fn = RemoteRewardManager.remote(config.worker.reward, tokenizer) - train_dataloader, val_dataloader = create_dataloader(config.data, tokenizer, processor) + train_dataloader, val_dataloader = create_dataloader( + config.data, tokenizer, processor + ) trainer = RayPPOTrainer( config=config, @@ -99,10 +107,10 @@ def run(self, config: PPOConfig): def main(): cli_args = OmegaConf.from_cli() default_config = OmegaConf.structured(PPOConfig()) - with open('tokens.json', 'r') as f: + with open("tokens.json", "r") as f: tokens = json.load(f) - os.environ['HF_TOKEN'] = tokens['huggingface'] - os.environ['WANDB_API_KEY'] = tokens['wandb'] + os.environ["HF_TOKEN"] = tokens["huggingface"] + os.environ["WANDB_API_KEY"] = tokens["wandb"] if hasattr(cli_args, "config"): config_path = cli_args.pop("config", None) file_config = OmegaConf.load(config_path) @@ -123,7 +131,7 @@ def main(): "PYTHONUNBUFFERED": "1", } } - ray.init(runtime_env=runtime_env,num_cpus=16) + ray.init(runtime_env=runtime_env, num_cpus=16) runner = Runner.remote() ray.get(runner.run.remote(ppo_config)) diff --git a/Agent0/curriculum_train/verl/trainer/metrics.py b/Agent0/curriculum_train/verl/trainer/metrics.py index 02cd233..b305af5 100644 --- a/Agent0/curriculum_train/verl/trainer/metrics.py +++ b/Agent0/curriculum_train/verl/trainer/metrics.py @@ -73,7 +73,9 @@ def compute_data_metrics(batch: DataProto, use_critic: bool = False) -> Dict[str "critic/values/max": torch.max(valid_values).detach().item(), "critic/values/min": torch.min(valid_values).detach().item(), # vf explained var - "critic/vf_explained_var": (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(), + "critic/vf_explained_var": (1.0 - return_diff_var / 
(return_var + 1e-5)) + .detach() + .item(), } if use_critic else {} @@ -82,35 +84,50 @@ def compute_data_metrics(batch: DataProto, use_critic: bool = False) -> Dict[str "response_length/mean": torch.mean(response_length).detach().item(), "response_length/max": torch.max(response_length).detach().item(), "response_length/min": torch.min(response_length).detach().item(), - "response_length/clip_ratio": torch.mean(torch.eq(response_length, max_response_length).float()) + "response_length/clip_ratio": torch.mean( + torch.eq(response_length, max_response_length).float() + ) .detach() .item(), # prompt length "prompt_length/mean": torch.mean(prompt_length).detach().item(), "prompt_length/max": torch.max(prompt_length).detach().item(), "prompt_length/min": torch.min(prompt_length).detach().item(), - "prompt_length/clip_ratio": torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(), + "prompt_length/clip_ratio": torch.mean( + torch.eq(prompt_length, max_prompt_length).float() + ) + .detach() + .item(), } return metrics -def compute_timing_metrics(batch: DataProto, timing_raw: Dict[str, float]) -> Dict[str, Any]: +def compute_timing_metrics( + batch: DataProto, timing_raw: Dict[str, float] +) -> Dict[str, Any]: num_response_tokens = torch.sum(batch.batch["response_mask"]).item() num_overall_tokens = sum(batch.meta_info["global_token_num"]) num_tokens_of_section = { **dict.fromkeys(["gen", "reward"], num_response_tokens), - **dict.fromkeys(["ref", "old", "values", "adv", "update_critic", "update_actor"], num_overall_tokens), + **dict.fromkeys( + ["ref", "old", "values", "adv", "update_critic", "update_actor"], + num_overall_tokens, + ), } return { **{f"timing_s/{name}": value for name, value in timing_raw.items()}, **{ - f"timing_per_token_ms/{name}": timing_raw[name] * 1000 / num_tokens_of_section[name] + f"timing_per_token_ms/{name}": timing_raw[name] + * 1000 + / num_tokens_of_section[name] for name in set(num_tokens_of_section.keys()) & 
set(timing_raw.keys()) }, } -def compute_throughout_metrics(batch: DataProto, timing_raw: Dict[str, float], num_gpus: int) -> Dict[str, Any]: +def compute_throughout_metrics( + batch: DataProto, timing_raw: Dict[str, float], num_gpus: int +) -> Dict[str, Any]: total_num_tokens = sum(batch.meta_info["global_token_num"]) time = timing_raw["step"] return { diff --git a/Agent0/curriculum_train/verl/trainer/ray_trainer.py b/Agent0/curriculum_train/verl/trainer/ray_trainer.py index 0ba89d3..50fe73f 100644 --- a/Agent0/curriculum_train/verl/trainer/ray_trainer.py +++ b/Agent0/curriculum_train/verl/trainer/ray_trainer.py @@ -33,18 +33,30 @@ from ..protocol import DataProto, pad_dataproto_to_divisor, unpad_dataproto from ..single_controller.base import Worker -from ..single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup +from ..single_controller.ray import ( + RayClassWithInitArgs, + RayResourcePool, + RayWorkerGroup, +) from ..single_controller.ray.base import create_colocated_worker_cls from ..utils import torch_functional as VF from ..utils.checkpoint import CHECKPOINT_TRACKER, remove_obsolete_ckpt from ..utils.logger import Tracker from ..utils.py_functional import convert_dict_to_str, timer -from ..utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance +from ..utils.seqlen_balancing import ( + get_seqlen_balanced_partitions, + log_seqlen_unbalance, +) from ..workers.fsdp_workers import FSDPWorker from ..workers.reward import FunctionRewardManager from . import core_algos from .config import PPOConfig -from .metrics import compute_data_metrics, compute_throughout_metrics, compute_timing_metrics, reduce_metrics +from .metrics import ( + compute_data_metrics, + compute_throughout_metrics, + compute_timing_metrics, + reduce_metrics, +) class Role(IntEnum): @@ -89,7 +101,10 @@ def create_resource_pool(self): # For FSDP backend, we recommend using max_colocate_count=1 that merge all WorkerGroups into one. 
# For Megatron backend, we recommend using max_colocate_count>1 that can utilize different WorkerGroup for differnt models resource_pool = RayResourcePool( - process_on_nodes=process_on_nodes, use_gpu=True, max_colocate_count=1, name_prefix=resource_pool_name + process_on_nodes=process_on_nodes, + use_gpu=True, + max_colocate_count=1, + name_prefix=resource_pool_name, ) self.resource_pool_dict[resource_pool_name] = resource_pool @@ -101,28 +116,42 @@ def get_resource_pool(self, role: Role) -> RayResourcePool: def get_num_gpus(self) -> int: """Get the number of gpus in this cluster.""" - return sum([n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes]) + return sum( + [ + n_gpus + for process_on_nodes in self.resource_pool_spec.values() + for n_gpus in process_on_nodes + ] + ) def _check_resource_available(self): """Check if the resource pool can be satisfied in this ray cluster.""" gpus_available = ray.available_resources().get("GPU", 0) gpus_required = self.get_num_gpus() if gpus_available < gpus_required: - raise ValueError(f"Total available GPUs {gpus_available} is less than total desired GPUs {gpus_required}.") + raise ValueError( + f"Total available GPUs {gpus_available} is less than total desired GPUs {gpus_required}." 
+ ) -def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.KLController, kl_penalty="kl"): +def apply_kl_penalty( + data: DataProto, kl_ctrl: core_algos.KLController, kl_penalty="kl" +): token_level_scores = data.batch["token_level_scores"] batch_size = data.batch.batch_size[0] response_mask = data.batch["response_mask"] # compute kl between ref_policy and current policy - kld = core_algos.compute_kl(data.batch["old_log_probs"], data.batch["ref_log_probs"], kl_penalty=kl_penalty) + kld = core_algos.compute_kl( + data.batch["old_log_probs"], data.batch["ref_log_probs"], kl_penalty=kl_penalty + ) kld = kld * response_mask # (batch_size, response_length) data.batch["token_level_rewards"] = token_level_scores - kl_ctrl.kl_coef * kld - current_kl = VF.masked_mean(kld, mask=response_mask, dim=-1) # average over sequence + current_kl = VF.masked_mean( + kld, mask=response_mask, dim=-1 + ) # average over sequence current_kl = torch.mean(current_kl, dim=0).item() metrics = {"critic/kl": current_kl, "critic/kl_coef": kl_ctrl.kl_coef} @@ -131,7 +160,12 @@ def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.KLController, kl_penal return data, metrics -def compute_advantage(data: DataProto, adv_estimator: AdvantageEstimator, gamma: float = 1.0, lam: float = 1.0): +def compute_advantage( + data: DataProto, + adv_estimator: AdvantageEstimator, + gamma: float = 1.0, + lam: float = 1.0, +): token_level_rewards = data.batch["token_level_rewards"] response_mask = data.batch["response_mask"] index = data.non_tensor_batch["uid"] @@ -141,7 +175,9 @@ def compute_advantage(data: DataProto, adv_estimator: AdvantageEstimator, gamma: token_level_rewards, values, response_mask, gamma, lam ) elif adv_estimator == AdvantageEstimator.GRPO: - advantages, returns = core_algos.compute_grpo_outcome_advantage(token_level_rewards, response_mask, index) + advantages, returns = core_algos.compute_grpo_outcome_advantage( + token_level_rewards, response_mask, index + ) elif adv_estimator == 
AdvantageEstimator.REINFORCE_PLUS_PLUS: advantages, returns = core_algos.compute_reinforce_plus_plus_outcome_advantage( token_level_rewards, response_mask, gamma @@ -152,7 +188,9 @@ def compute_advantage(data: DataProto, adv_estimator: AdvantageEstimator, gamma: token_level_rewards, reward_baselines, response_mask ) elif adv_estimator == AdvantageEstimator.RLOO: - advantages, returns = core_algos.compute_rloo_outcome_advantage(token_level_rewards, response_mask, index) + advantages, returns = core_algos.compute_rloo_outcome_advantage( + token_level_rewards, response_mask, index + ) else: raise NotImplementedError @@ -189,9 +227,9 @@ def __init__( self.hybrid_engine = config.worker.hybrid_engine if self.hybrid_engine: - assert Role.ActorRollout in role_worker_mapping, ( - f"ActorRollout should be included in {role_worker_mapping.keys()}." - ) + assert ( + Role.ActorRollout in role_worker_mapping + ), f"ActorRollout should be included in {role_worker_mapping.keys()}." else: raise NotImplementedError @@ -207,7 +245,9 @@ def __init__( else: self.use_reference_policy = False self.kl_ctrl = core_algos.FixedKLController(init_kl_coef=0.0) - print("KL is disabled, no KL metrics will be logged. Please set `kl_coef=0` to log KL metrics.") + print( + "KL is disabled, no KL metrics will be logged. Please set `kl_coef=0` to log KL metrics." + ) if config.algorithm.adv_estimator == AdvantageEstimator.GAE: self.use_critic = True @@ -215,10 +255,14 @@ def __init__( self.use_critic = False if config.algorithm.adv_estimator not in list(AdvantageEstimator): - raise NotImplementedError(f"Unknown advantage estimator: {config.algorithm.adv_estimator}.") + raise NotImplementedError( + f"Unknown advantage estimator: {config.algorithm.adv_estimator}." 
+ ) if config.data.rollout_batch_size % config.worker.actor.global_batch_size != 0: - raise ValueError("Rollout batch size must be divisible by actor global batch size.") + raise ValueError( + "Rollout batch size must be divisible by actor global batch size." + ) if ( config.data.rollout_batch_size * config.worker.rollout.n @@ -228,8 +272,13 @@ def __init__( ) if self.use_critic: - if config.data.rollout_batch_size % config.worker.critic.global_batch_size != 0: - raise ValueError("Rollout batch size must be divisible by critic global batch size.") + if ( + config.data.rollout_batch_size % config.worker.critic.global_batch_size + != 0 + ): + raise ValueError( + "Rollout batch size must be divisible by critic global batch size." + ) if ( config.data.rollout_batch_size * config.worker.rollout.n @@ -239,10 +288,13 @@ def __init__( ) if ( - config.algorithm.adv_estimator in (AdvantageEstimator.GRPO, AdvantageEstimator.RLOO) + config.algorithm.adv_estimator + in (AdvantageEstimator.GRPO, AdvantageEstimator.RLOO) and config.worker.rollout.n == 1 ): - raise ValueError("GRPO and RLOO algorithm need `config.worker.rollout.n > 1`.") + raise ValueError( + "GRPO and RLOO algorithm need `config.worker.rollout.n > 1`." 
+ ) if config.trainer.max_steps is not None: self.training_steps = config.trainer.max_steps @@ -254,7 +306,11 @@ def __init__( print(f"Total training steps: {self.training_steps}") def _maybe_log_val_generations( - self, inputs: List[str], outputs: List[str], labels: List[str], scores: List[float] + self, + inputs: List[str], + outputs: List[str], + labels: List[str], + scores: List[float], ) -> None: """Log a table of validation samples""" if self.config.trainer.val_generations_to_log <= 0: @@ -280,7 +336,10 @@ def _validate(self) -> Dict[str, Any]: test_batch = DataProto.from_single_dict(batch_dict) # Store original inputs input_ids = test_batch.batch["input_ids"] - input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids] + input_texts = [ + self.tokenizer.decode(ids, skip_special_tokens=True) + for ids in input_ids + ] sample_inputs.extend(input_texts) if "multi_modal_data" in test_batch.non_tensor_batch.keys(): @@ -295,23 +354,36 @@ def _validate(self) -> Dict[str, Any]: ) test_gen_batch.meta_info = self.config.worker.rollout.val_override_config - test_gen_batch.meta_info.update({ - "min_pixels": self.config.data.min_pixels, - "max_pixels": self.config.data.max_pixels, - }) - test_gen_batch, pad_size = pad_dataproto_to_divisor(test_gen_batch, self.actor_rollout_wg.world_size) - test_output_gen_batch = self.actor_rollout_wg.generate_sequences(test_gen_batch) - test_output_gen_batch = unpad_dataproto(test_output_gen_batch, pad_size=pad_size) + test_gen_batch.meta_info.update( + { + "min_pixels": self.config.data.min_pixels, + "max_pixels": self.config.data.max_pixels, + } + ) + test_gen_batch, pad_size = pad_dataproto_to_divisor( + test_gen_batch, self.actor_rollout_wg.world_size + ) + test_output_gen_batch = self.actor_rollout_wg.generate_sequences( + test_gen_batch + ) + test_output_gen_batch = unpad_dataproto( + test_output_gen_batch, pad_size=pad_size + ) # Store generated outputs output_ids = 
test_output_gen_batch.batch["responses"] - output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids] + output_texts = [ + self.tokenizer.decode(ids, skip_special_tokens=True) + for ids in output_ids + ] sample_outputs.extend(output_texts) sample_labels.extend(test_batch.non_tensor_batch["ground_truth"].tolist()) test_batch = test_batch.union(test_output_gen_batch) # evaluate using reward_function - reward_tensor, reward_metrics = ray.get(self.val_reward_fn.compute_reward.remote(test_batch)) + reward_tensor, reward_metrics = ray.get( + self.val_reward_fn.compute_reward.remote(test_batch) + ) # Store scores scores = reward_tensor.sum(-1).cpu().tolist() @@ -321,23 +393,36 @@ def _validate(self) -> Dict[str, Any]: for key, value in reward_metrics.items(): reward_metrics_lst[key].extend(value) - self._maybe_log_val_generations(sample_inputs, sample_outputs, sample_labels, sample_scores) + self._maybe_log_val_generations( + sample_inputs, sample_outputs, sample_labels, sample_scores + ) reward_score = torch.cat(reward_tensor_lst, dim=0).sum(-1).mean().item() - val_reward_metrics = {f"val/{key}_reward": value for key, value in reduce_metrics(reward_metrics_lst).items()} + val_reward_metrics = { + f"val/{key}_reward": value + for key, value in reduce_metrics(reward_metrics_lst).items() + } return {"val/reward_score": reward_score, **val_reward_metrics} def init_workers(self) -> None: """Init resource pool and worker group""" self.resource_pool_manager.create_resource_pool() - self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()} + self.resource_pool_to_cls = { + pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values() + } # create actor and rollout if self.hybrid_engine: - resource_pool = self.resource_pool_manager.get_resource_pool(Role.ActorRollout) + resource_pool = self.resource_pool_manager.get_resource_pool( + Role.ActorRollout + ) actor_rollout_cls = 
RayClassWithInitArgs( - cls=self.role_worker_mapping[Role.ActorRollout], config=self.config.worker, role="actor_rollout" + cls=self.role_worker_mapping[Role.ActorRollout], + config=self.config.worker, + role="actor_rollout", ) - self.resource_pool_to_cls[resource_pool]["actor_rollout"] = actor_rollout_cls + self.resource_pool_to_cls[resource_pool][ + "actor_rollout" + ] = actor_rollout_cls else: raise NotImplementedError @@ -345,7 +430,9 @@ def init_workers(self) -> None: if self.use_critic: resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic) critic_cls = RayClassWithInitArgs( - cls=self.role_worker_mapping[Role.Critic], config=self.config.worker, role="critic" + cls=self.role_worker_mapping[Role.Critic], + config=self.config.worker, + role="critic", ) self.resource_pool_to_cls[resource_pool]["critic"] = critic_cls @@ -353,16 +440,22 @@ def init_workers(self) -> None: if self.use_reference_policy: resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy) ref_policy_cls = RayClassWithInitArgs( - self.role_worker_mapping[Role.RefPolicy], config=self.config.worker, role="ref" + self.role_worker_mapping[Role.RefPolicy], + config=self.config.worker, + role="ref", ) self.resource_pool_to_cls[resource_pool]["ref"] = ref_policy_cls # create a reward model if reward_fn is None if self.use_reward_model: # we create a RM here - resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel) + resource_pool = self.resource_pool_manager.get_resource_pool( + Role.RewardModel + ) rm_cls = RayClassWithInitArgs( - cls=self.role_worker_mapping[Role.RewardModel], config=self.config.worker, role="reward" + cls=self.role_worker_mapping[Role.RewardModel], + config=self.config.worker, + role="reward", ) self.resource_pool_to_cls[resource_pool]["rm"] = rm_cls @@ -374,7 +467,9 @@ def init_workers(self) -> None: self.wg_dicts = [] for resource_pool, class_dict in self.resource_pool_to_cls.items(): worker_dict_cls = 
create_colocated_worker_cls(class_dict=class_dict) - wg_dict = self.ray_worker_group_cls(resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls) + wg_dict = self.ray_worker_group_cls( + resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls + ) spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys()) all_wg.update(spawn_wg) # keep the referece of WorkerDict to support ray >= 2.31. Ref: https://github.com/ray-project/ray/pull/45699 @@ -399,9 +494,13 @@ def init_workers(self) -> None: def _save_checkpoint(self) -> None: # path: {save_checkpoint_path}/global_step_{global_step}/{actor,critic} remove_obsolete_ckpt( - self.config.trainer.save_checkpoint_path, self.global_step, self.config.trainer.save_limit + self.config.trainer.save_checkpoint_path, + self.global_step, + self.config.trainer.save_limit, + ) + folder_path = os.path.join( + self.config.trainer.save_checkpoint_path, f"global_step_{self.global_step}" ) - folder_path = os.path.join(self.config.trainer.save_checkpoint_path, f"global_step_{self.global_step}") actor_path = os.path.join(folder_path, "actor") self.actor_rollout_wg.save_checkpoint(actor_path) @@ -413,7 +512,9 @@ def _save_checkpoint(self) -> None: dataloader_state_dict = self.train_dataloader.state_dict() torch.save(dataloader_state_dict, dataloader_path) - last_global_step_path = os.path.join(self.config.trainer.save_checkpoint_path, CHECKPOINT_TRACKER) + last_global_step_path = os.path.join( + self.config.trainer.save_checkpoint_path, CHECKPOINT_TRACKER + ) with open(last_global_step_path, "w") as f: f.write(str(self.global_step)) @@ -421,38 +522,64 @@ def _load_checkpoint(self) -> None: if self.config.trainer.load_checkpoint_path is None: return - if "global_step_" not in self.config.trainer.load_checkpoint_path.strip(os.path.sep).split(os.path.sep)[-1]: + if ( + "global_step_" + not in self.config.trainer.load_checkpoint_path.strip(os.path.sep).split( + os.path.sep + )[-1] + ): raise ValueError("`load_checkpoint_path` should end 
with `global_step_*`.") print(f"Load from checkpoint: {self.config.trainer.load_checkpoint_path}.") - self.global_step = int(self.config.trainer.load_checkpoint_path.strip(os.path.sep).split("global_step_")[-1]) + self.global_step = int( + self.config.trainer.load_checkpoint_path.strip(os.path.sep).split( + "global_step_" + )[-1] + ) actor_path = os.path.join(self.config.trainer.load_checkpoint_path, "actor") self.actor_rollout_wg.load_checkpoint(actor_path) if self.use_critic: - critic_path = os.path.join(self.config.trainer.load_checkpoint_path, "critic") + critic_path = os.path.join( + self.config.trainer.load_checkpoint_path, "critic" + ) self.critic_wg.load_checkpoint(critic_path) - dataloader_path = os.path.join(self.config.trainer.load_checkpoint_path, "dataloader.pt") + dataloader_path = os.path.join( + self.config.trainer.load_checkpoint_path, "dataloader.pt" + ) if os.path.exists(dataloader_path): dataloader_state_dict = torch.load(dataloader_path, weights_only=False) self.train_dataloader.load_state_dict(dataloader_state_dict) else: - print(f"No dataloader state found at {dataloader_path}, will start from scratch.") + print( + f"No dataloader state found at {dataloader_path}, will start from scratch." 
+ ) - def _balance_batch(self, batch: DataProto, metrics: Dict[str, Any], logging_prefix: str = "global_seqlen") -> None: + def _balance_batch( + self, + batch: DataProto, + metrics: Dict[str, Any], + logging_prefix: str = "global_seqlen", + ) -> None: """Reorder the data on single controller such that each dp rank gets similar total tokens""" attention_mask = batch.batch["attention_mask"] batch_size = attention_mask.shape[0] - global_seqlen_lst = batch.batch["attention_mask"].view(batch_size, -1).sum(-1).tolist() # (train_batch_size,) + global_seqlen_lst = ( + batch.batch["attention_mask"].view(batch_size, -1).sum(-1).tolist() + ) # (train_batch_size,) world_size = self.actor_rollout_wg.world_size global_partition_lst = get_seqlen_balanced_partitions( global_seqlen_lst, k_partitions=world_size, equal_size=True ) # reorder based on index. The data will be automatically equally partitioned by dispatch function - global_idx = torch.tensor([j for partition in global_partition_lst for j in partition]) + global_idx = torch.tensor( + [j for partition in global_partition_lst for j in partition] + ) batch.reorder(global_idx) global_balance_stats = log_seqlen_unbalance( - seqlen_list=global_seqlen_lst, partitions=global_partition_lst, prefix=logging_prefix + seqlen_list=global_seqlen_lst, + partitions=global_partition_lst, + prefix=logging_prefix, ) metrics.update(global_balance_stats) @@ -462,7 +589,9 @@ def fit(self): The driver process only need to call the compute functions of the worker group through RPC to construct the PPO dataflow. The light-weight advantage computation is done on the driver process. 
""" - self.logger = Tracker(loggers=self.config.trainer.logger, config=self.config.to_dict()) + self.logger = Tracker( + loggers=self.config.trainer.logger, config=self.config.to_dict() + ) self.global_step = 0 val_metrics: Optional[Dict[str, Any]] = None @@ -477,8 +606,12 @@ def fit(self): if self.config.trainer.val_only: return - for _ in tqdm(range(self.config.trainer.total_epochs), desc="Epoch", position=0): - for batch_dict in tqdm(self.train_dataloader, desc="Running step", position=1): + for _ in tqdm( + range(self.config.trainer.total_epochs), desc="Epoch", position=0 + ): + for batch_dict in tqdm( + self.train_dataloader, desc="Running step", position=1 + ): self.global_step += 1 if self.global_step > self.training_steps: break @@ -492,10 +625,12 @@ def fit(self): batch_keys=["input_ids", "attention_mask", "position_ids"], non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data"], ) - gen_batch.meta_info.update({ - "min_pixels": self.config.data.min_pixels, - "max_pixels": self.config.data.max_pixels, - }) + gen_batch.meta_info.update( + { + "min_pixels": self.config.data.min_pixels, + "max_pixels": self.config.data.max_pixels, + } + ) else: gen_batch = batch.pop( batch_keys=["input_ids", "attention_mask", "position_ids"], @@ -505,17 +640,25 @@ def fit(self): with timer("step", timing_raw): # generate a batch with timer("gen", timing_raw): # wg: worker group - gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) + gen_batch_output = self.actor_rollout_wg.generate_sequences( + gen_batch + ) if self.config.algorithm.adv_estimator == "remax": with timer("gen_max", timing_raw): gen_baseline_batch = deepcopy(gen_batch) gen_baseline_batch.meta_info["temperature"] = 0 gen_baseline_batch.meta_info["n"] = 1 - gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch) + gen_baseline_output = ( + self.actor_rollout_wg.generate_sequences( + gen_baseline_batch + ) + ) batch = batch.union(gen_baseline_output) - 
reward_baseline_tensor, _ = ray.get(self.reward_fn.compute_reward.remote(batch)) + reward_baseline_tensor, _ = ray.get( + self.reward_fn.compute_reward.remote(batch) + ) reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1) batch.pop(batch_keys=list(gen_baseline_output.batch.keys())) @@ -523,10 +666,13 @@ def fit(self): del gen_baseline_batch, gen_baseline_output batch.non_tensor_batch["uid"] = np.array( - [str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object + [str(uuid.uuid4()) for _ in range(len(batch.batch))], + dtype=object, ) # repeat to align with repeated responses in rollout - batch = batch.repeat(repeat_times=self.config.worker.rollout.n, interleave=True) + batch = batch.repeat( + repeat_times=self.config.worker.rollout.n, interleave=True + ) batch = batch.union(gen_batch_output) # balance the number of valid tokens on each dp rank. @@ -535,7 +681,9 @@ def fit(self): self._balance_batch(batch, metrics=metrics) # compute global_valid tokens - batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist() + batch.meta_info["global_token_num"] = torch.sum( + batch.batch["attention_mask"], dim=-1 + ).tolist() # compute reward with timer("reward", timing_raw): @@ -549,7 +697,9 @@ def fit(self): # compute ref_log_probs if self.use_reference_policy: with timer("ref", timing_raw): - ref_log_probs = self.ref_policy_wg.compute_ref_log_probs(batch) + ref_log_probs = self.ref_policy_wg.compute_ref_log_probs( + batch + ) batch = batch.union(ref_log_probs) # compute values @@ -562,16 +712,26 @@ def fit(self): # get token level scores reward_tensor, reward_metrics = ray.get(reward_ref) batch.batch["token_level_scores"] = reward_tensor - reward_metrics = {f"reward/{k}": v for k, v in reduce_metrics(reward_metrics).items()} + reward_metrics = { + f"reward/{k}": v + for k, v in reduce_metrics(reward_metrics).items() + } metrics.update(reward_metrics) # apply kl penalty if available - if not 
self.config.algorithm.use_kl_loss and self.use_reference_policy: + if ( + not self.config.algorithm.use_kl_loss + and self.use_reference_policy + ): # apply kl penalty to reward - batch, kl_metrics = apply_kl_penalty(batch, self.kl_ctrl, self.config.algorithm.kl_penalty) + batch, kl_metrics = apply_kl_penalty( + batch, self.kl_ctrl, self.config.algorithm.kl_penalty + ) metrics.update(kl_metrics) else: - batch.batch["token_level_rewards"] = batch.batch["token_level_scores"] + batch.batch["token_level_rewards"] = batch.batch[ + "token_level_scores" + ] # compute advantages, executed on the driver process batch = compute_advantage( @@ -608,15 +768,26 @@ def fit(self): metrics.update(val_metrics) - if self.config.trainer.save_freq > 0 and self.global_step % self.config.trainer.save_freq == 0: + if ( + self.config.trainer.save_freq > 0 + and self.global_step % self.config.trainer.save_freq == 0 + ): with timer("save_checkpoint", timing_raw): self._save_checkpoint() # collect metrics num_gpus = self.resource_pool_manager.get_num_gpus() - metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic)) - metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) - metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, num_gpus=num_gpus)) + metrics.update( + compute_data_metrics(batch=batch, use_critic=self.use_critic) + ) + metrics.update( + compute_timing_metrics(batch=batch, timing_raw=timing_raw) + ) + metrics.update( + compute_throughout_metrics( + batch=batch, timing_raw=timing_raw, num_gpus=num_gpus + ) + ) self.logger.log(data=metrics, step=self.global_step) @@ -632,5 +803,8 @@ def fit(self): print(f"Final validation metrics: {convert_dict_to_str(val_metrics)}") - if self.config.trainer.save_freq <= 0 or self.global_step % self.config.trainer.save_freq != 0: + if ( + self.config.trainer.save_freq <= 0 + or self.global_step % self.config.trainer.save_freq != 0 + ): self._save_checkpoint() diff --git 
a/Agent0/curriculum_train/verl/utils/checkpoint/checkpoint_manager.py b/Agent0/curriculum_train/verl/utils/checkpoint/checkpoint_manager.py index 749b60c..02bb5d6 100644 --- a/Agent0/curriculum_train/verl/utils/checkpoint/checkpoint_manager.py +++ b/Agent0/curriculum_train/verl/utils/checkpoint/checkpoint_manager.py @@ -85,7 +85,9 @@ def local_mkdir(path: str) -> str: os.makedirs(path, exist_ok=True) except Exception as e: print(f"Warning: Failed to acquire lock for {path}: {e}") - os.makedirs(path, exist_ok=True) # even if the lock is not acquired, try to create the directory + os.makedirs( + path, exist_ok=True + ) # even if the lock is not acquired, try to create the directory return path @@ -107,7 +109,9 @@ def load_rng_state(rng_state: Dict[str, Any]): random.setstate(rng_state["random"]) -def find_latest_ckpt_path(path: Optional[str] = None, directory_format: str = "global_step_{}") -> Optional[str]: +def find_latest_ckpt_path( + path: Optional[str] = None, directory_format: str = "global_step_{}" +) -> Optional[str]: if path is None: return None @@ -135,7 +139,12 @@ def get_checkpoint_tracker_filename(root_path: str) -> str: return os.path.join(root_path, CHECKPOINT_TRACKER) -def remove_obsolete_ckpt(path: str, global_step: int, save_limit: int = -1, directory_format: str = "global_step_{}"): +def remove_obsolete_ckpt( + path: str, + global_step: int, + save_limit: int = -1, + directory_format: str = "global_step_{}", +): """ Remove the obsolete checkpoints that exceed the save_limit. 
""" diff --git a/Agent0/curriculum_train/verl/utils/checkpoint/fsdp_checkpoint_manager.py b/Agent0/curriculum_train/verl/utils/checkpoint/fsdp_checkpoint_manager.py index 1318bfe..87a1123 100644 --- a/Agent0/curriculum_train/verl/utils/checkpoint/fsdp_checkpoint_manager.py +++ b/Agent0/curriculum_train/verl/utils/checkpoint/fsdp_checkpoint_manager.py @@ -17,7 +17,11 @@ import torch import torch.distributed as dist -from torch.distributed.checkpoint.state_dict import StateDictOptions, get_state_dict, set_state_dict +from torch.distributed.checkpoint.state_dict import ( + StateDictOptions, + get_state_dict, + set_state_dict, +) from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin @@ -53,12 +57,22 @@ def load_checkpoint(self, path: Optional[str] = None): return # every rank download its own checkpoint - model_path = os.path.join(path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt") - optim_path = os.path.join(path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt") - extra_path = os.path.join(path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt") + model_path = os.path.join( + path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt" + ) + optim_path = os.path.join( + path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt" + ) + extra_path = os.path.join( + path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt" + ) print(f"[rank-{self.rank}]: Loading model from {os.path.abspath(model_path)}.") - print(f"[rank-{self.rank}]: Loading optimizer from {os.path.abspath(optim_path)}.") - print(f"[rank-{self.rank}]: Loading extra_state from {os.path.abspath(extra_path)}.") + print( + f"[rank-{self.rank}]: Loading optimizer from {os.path.abspath(optim_path)}." + ) + print( + f"[rank-{self.rank}]: Loading extra_state from {os.path.abspath(extra_path)}." 
+ ) model_state_dict = torch.load(model_path, weights_only=False) optim_state_dict = torch.load(optim_path, weights_only=False) extra_state_dict = torch.load(extra_path, weights_only=False) @@ -83,18 +97,28 @@ def save_checkpoint(self, path: str): # every rank will save its own model and optim shard state_dict_options = StateDictOptions(cpu_offload=True) - model_state_dict, optim_state_dict = get_state_dict(self.model, self.optimizer, options=state_dict_options) + model_state_dict, optim_state_dict = get_state_dict( + self.model, self.optimizer, options=state_dict_options + ) extra_state_dict = { "lr_scheduler": self.lr_scheduler.state_dict(), "rng": self.get_rng_state(), } - model_path = os.path.join(path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt") - optim_path = os.path.join(path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt") - extra_path = os.path.join(path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt") + model_path = os.path.join( + path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt" + ) + optim_path = os.path.join( + path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt" + ) + extra_path = os.path.join( + path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt" + ) print(f"[rank-{self.rank}]: Saving model to {os.path.abspath(model_path)}.") print(f"[rank-{self.rank}]: Saving optimizer to {os.path.abspath(optim_path)}.") - print(f"[rank-{self.rank}]: Saving extra_state to {os.path.abspath(extra_path)}.") + print( + f"[rank-{self.rank}]: Saving extra_state to {os.path.abspath(extra_path)}." 
+ ) torch.save(model_state_dict, model_path) torch.save(optim_state_dict, optim_path) torch.save(extra_state_dict, extra_path) diff --git a/Agent0/curriculum_train/verl/utils/code_executor.py b/Agent0/curriculum_train/verl/utils/code_executor.py index 82b9c67..29c4f60 100644 --- a/Agent0/curriculum_train/verl/utils/code_executor.py +++ b/Agent0/curriculum_train/verl/utils/code_executor.py @@ -2,7 +2,8 @@ import json import re -SANDBOX_API_URL = 'http://172.22.1.105:8080/run_code' +SANDBOX_API_URL = "http://172.22.1.105:8080/run_code" + def execute_code_in_sandbox(code: str) -> str: """ @@ -14,15 +15,12 @@ def execute_code_in_sandbox(code: str) -> str: Returns: ๆ‰ง่กŒ็ป“ๆžœ๏ผˆstdout๏ผ‰๏ผŒๅฆ‚ๆžœๅ‡บ้”™ๅˆ™่ฟ”ๅ›ž้”™่ฏฏไฟกๆฏใ€‚ """ - payload = { - "code": code, - "language": "python" - } - headers = { - 'Content-Type': 'application/json' - } - - response = requests.post(SANDBOX_API_URL, headers=headers, data=json.dumps(payload), timeout=10) + payload = {"code": code, "language": "python"} + headers = {"Content-Type": "application/json"} + + response = requests.post( + SANDBOX_API_URL, headers=headers, data=json.dumps(payload), timeout=10 + ) response.raise_for_status() result = response.json() @@ -37,13 +35,13 @@ def execute_code_in_sandbox(code: str) -> str: return f"{result}" -if __name__ == '__main__': +if __name__ == "__main__": hello_world_code = 'print("Hello, world!")' print(f"Executing code:\n---\n{hello_world_code}\n---") output = execute_code_in_sandbox(hello_world_code) print(f"Result:\n---\n{output}\n---") - error_code = 'print(1 / 0)' + error_code = "print(1 / 0)" print(f"Executing code with error:\n---\n{error_code}\n---") output = execute_code_in_sandbox(error_code) - print(f"Result:\n---\n{output}\n---") \ No newline at end of file + print(f"Result:\n---\n{output}\n---") diff --git a/Agent0/curriculum_train/verl/utils/dataset.py b/Agent0/curriculum_train/verl/utils/dataset.py index 0002e53..fc0089a 100644 --- 
a/Agent0/curriculum_train/verl/utils/dataset.py +++ b/Agent0/curriculum_train/verl/utils/dataset.py @@ -32,6 +32,8 @@ import json import random + + def collate_fn(features: List[Dict[str, Any]]) -> Dict[str, Any]: tensors = defaultdict(list) non_tensors = defaultdict(list) @@ -51,8 +53,9 @@ def collate_fn(features: List[Dict[str, Any]]) -> Dict[str, Any]: return {**tensors, **non_tensors} - -def process_image(image: Union[Dict[str, Any], ImageObject, str], min_pixels: int, max_pixels: int) -> ImageObject: +def process_image( + image: Union[Dict[str, Any], ImageObject, str], min_pixels: int, max_pixels: int +) -> ImageObject: if isinstance(image, str): image = Image.open(image) elif isinstance(image, dict): @@ -62,12 +65,16 @@ def process_image(image: Union[Dict[str, Any], ImageObject, str], min_pixels: in if (image.width * image.height) > max_pixels: resize_factor = math.sqrt(max_pixels / (image.width * image.height)) - width, height = int(image.width * resize_factor), int(image.height * resize_factor) + width, height = int(image.width * resize_factor), int( + image.height * resize_factor + ) image = image.resize((width, height)) if (image.width * image.height) < min_pixels: resize_factor = math.sqrt(min_pixels / (image.width * image.height)) - width, height = int(image.width * resize_factor), int(image.height * resize_factor) + width, height = int(image.width * resize_factor), int( + image.height * resize_factor + ) image = image.resize((width, height)) if image.mode != "RGB": @@ -128,11 +135,15 @@ def __init__( if "questioner_format_with_persona" in self.format_prompt: print("load personas") - personas_dataset = load_dataset("proj-persona/PersonaHub", "math", split="train") - self.personas = [item['input persona'] for item in personas_dataset] + personas_dataset = load_dataset( + "proj-persona/PersonaHub", "math", split="train" + ) + self.personas = [item["input persona"] for item in personas_dataset] # self.personas = self.personas.select(range(100)) if 
self.filter_overlong_prompts: - self.dataset = self.dataset.filter(self._filter_overlong_prompts, desc="Filtering overlong prompts") + self.dataset = self.dataset.filter( + self._filter_overlong_prompts, desc="Filtering overlong prompts" + ) def _build_messages(self, example: Dict[str, Any]) -> List[Dict[str, Any]]: prompt_str: str = example[self.prompt_key] @@ -154,15 +165,15 @@ def _build_messages(self, example: Dict[str, Any]) -> List[Dict[str, Any]]: r"\boxed{final_answer}" "\n\n" "Do NOT output anything elseโ€”no explanations, no extra markup." - ) + ), }, { "role": "user", "content": ( "Generate one new, challenging reasoning question now. " "Remember to format the output exactly as instructed." - ) - } + ), + }, ] if "questioner_format" in self.format_prompt: # print('detected questioner_format') @@ -182,31 +193,28 @@ def _build_messages(self, example: Dict[str, Any]) -> List[Dict[str, Any]]: r"\boxed{final_answer}" "\n\n" "Do NOT output anything elseโ€”no explanations, no extra markup." - ) + ), }, { "role": "user", "content": ( "Generate one new, challenging reasoning question now. " "Remember to format the output exactly as instructed." - ) - } + ), + }, ] if "solver_format" in self.format_prompt: return [ { - "role": "system", - "content": r"Please reason step by step, and put your final answer within \boxed{}." 
+ "role": "system", + "content": r"Please reason step by step, and put your final answer within \boxed{}.", }, - { - "role": "user", - "content": prompt_str - } - ] + {"role": "user", "content": prompt_str}, + ] if self.format_prompt: format_prompt = Template(self.format_prompt.strip()) prompt_str = format_prompt.render(content=prompt_str) - + if self.image_key in example: # https://huggingface.co/docs/transformers/en/tasks/image_text_to_text content_list = [] @@ -223,16 +231,29 @@ def _build_messages(self, example: Dict[str, Any]) -> List[Dict[str, Any]]: def _filter_overlong_prompts(self, example: Dict[str, Any]) -> bool: messages = self._build_messages(example) - processing_class = self.processor if self.processor is not None else self.tokenizer + processing_class = ( + self.processor if self.processor is not None else self.tokenizer + ) if self.tokenizer.chat_template: return ( - len(processing_class.apply_chat_template(messages, add_generation_prompt=True)) <= self.max_prompt_length + len( + processing_class.apply_chat_template( + messages, add_generation_prompt=True + ) + ) + <= self.max_prompt_length ) else: return ( - len("system: " + messages[0]["content"] + '\n' + "user: " + messages[1]["content"]) <= self.max_prompt_length + len( + "system: " + + messages[0]["content"] + + "\n" + + "user: " + + messages[1]["content"] + ) + <= self.max_prompt_length ) - def __len__(self): return len(self.dataset) @@ -242,26 +263,46 @@ def __getitem__(self, index): messages = self._build_messages(example) if self.image_key in example: - prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False) + prompt = self.processor.apply_chat_template( + messages, add_generation_prompt=True, tokenize=False + ) raw_image_data = example.pop(self.image_key) images = [ - process_image(image, min_pixels=self.min_pixels, max_pixels=self.max_pixels) + process_image( + image, min_pixels=self.min_pixels, max_pixels=self.max_pixels + ) for image in 
raw_image_data ] - model_inputs = self.processor(images, [prompt], add_special_tokens=False, return_tensors="pt") + model_inputs = self.processor( + images, [prompt], add_special_tokens=False, return_tensors="pt" + ) input_ids = model_inputs.pop("input_ids")[0] attention_mask = model_inputs.pop("attention_mask")[0] example["multi_modal_data"] = {"image": raw_image_data} else: if self.tokenizer.chat_template: - prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False) + prompt = self.tokenizer.apply_chat_template( + messages, add_generation_prompt=True, tokenize=False + ) else: - prompt = "system: " + messages[0]["content"] + '\n' + "user: " + messages[1]["content"] - model_inputs = self.tokenizer([prompt], add_special_tokens=False, return_tensors="pt") + prompt = ( + "system: " + + messages[0]["content"] + + "\n" + + "user: " + + messages[1]["content"] + ) + model_inputs = self.tokenizer( + [prompt], add_special_tokens=False, return_tensors="pt" + ) input_ids = model_inputs.pop("input_ids")[0] attention_mask = model_inputs.pop("attention_mask")[0] - if self.processor is not None and self.processor.image_processor.__class__.__name__ == "Qwen2VLImageProcessor": + if ( + self.processor is not None + and self.processor.image_processor.__class__.__name__ + == "Qwen2VLImageProcessor" + ): # qwen2vl mrope position_ids = get_rope_index( self.processor, @@ -270,7 +311,9 @@ def __getitem__(self, index): attention_mask=attention_mask, ) # (3, seq_length) else: - position_ids = torch.clip(attention_mask.cumsum(dim=0) - 1, min=0, max=None) # (seq_length,) + position_ids = torch.clip( + attention_mask.cumsum(dim=0) - 1, min=0, max=None + ) # (seq_length,) input_ids, attention_mask, position_ids = VF.postprocess_data( input_ids=input_ids, @@ -288,7 +331,9 @@ def __getitem__(self, index): elif self.truncation == "right": raw_prompt_ids = raw_prompt_ids[: self.max_prompt_length] elif self.truncation == "error": - raise RuntimeError(f"Prompt 
length {len(raw_prompt_ids)} is longer than {self.max_prompt_length}.") + raise RuntimeError( + f"Prompt length {len(raw_prompt_ids)} is longer than {self.max_prompt_length}." + ) example["input_ids"] = input_ids example["attention_mask"] = attention_mask diff --git a/Agent0/curriculum_train/verl/utils/flops_counter.py b/Agent0/curriculum_train/verl/utils/flops_counter.py index dee7623..4e23536 100644 --- a/Agent0/curriculum_train/verl/utils/flops_counter.py +++ b/Agent0/curriculum_train/verl/utils/flops_counter.py @@ -66,7 +66,9 @@ class FlopsCounter: def __init__(self, config: "LlamaConfig"): if config.model_type not in VALID_MODLE_TYPE: - print(f"Only support {VALID_MODLE_TYPE}, but got {config.model_type}. MFU will always be zero.") + print( + f"Only support {VALID_MODLE_TYPE}, but got {config.model_type}. MFU will always be zero." + ) self.estimate_func = { "llama": self._estimate_llama_flops, @@ -76,10 +78,14 @@ def __init__(self, config: "LlamaConfig"): } self.config = config - def _estimate_unknown_flops(self, tokens_sum: int, batch_seqlens: List[int], delta_time: float) -> float: + def _estimate_unknown_flops( + self, tokens_sum: int, batch_seqlens: List[int], delta_time: float + ) -> float: return 0 - def _estimate_llama_flops(self, tokens_sum: int, batch_seqlens: List[int], delta_time: float) -> float: + def _estimate_llama_flops( + self, tokens_sum: int, batch_seqlens: List[int], delta_time: float + ) -> float: hidden_size = self.config.hidden_size vocab_size = self.config.vocab_size num_hidden_layers = self.config.num_hidden_layers @@ -95,7 +101,9 @@ def _estimate_llama_flops(self, tokens_sum: int, batch_seqlens: List[int], delta # non-attn per layer parm # Qwen2/LLama use SwiGelu, gate, having up and down linear layer in mlp mlp_N = hidden_size * intermediate_size * 3 - attn_linear_N = hidden_size * (q_size + k_size + v_size + num_attention_heads * head_dim) + attn_linear_N = hidden_size * ( + q_size + k_size + v_size + num_attention_heads * head_dim 
+ ) emd_and_lm_head_N = vocab_size * hidden_size * 2 # non-attn all_layer parm dense_N = (mlp_N + attn_linear_N) * num_hidden_layers + emd_and_lm_head_N @@ -107,14 +115,18 @@ def _estimate_llama_flops(self, tokens_sum: int, batch_seqlens: List[int], delta for seqlen in batch_seqlens: seqlen_square_sum += seqlen * seqlen - attn_qkv_flops = 12 * seqlen_square_sum * head_dim * num_attention_heads * num_hidden_layers + attn_qkv_flops = ( + 12 * seqlen_square_sum * head_dim * num_attention_heads * num_hidden_layers + ) # all_layer & all_token fwd & bwd flops flops_all_token = dense_N_flops + attn_qkv_flops flops_achieved = flops_all_token * (1.0 / delta_time) / 1e12 return flops_achieved - def estimate_flops(self, batch_seqlens: List[int], delta_time: float) -> Tuple[float, float]: + def estimate_flops( + self, batch_seqlens: List[int], delta_time: float + ) -> Tuple[float, float]: """ Estimate the FLOPS based on the number of valid tokens in the current batch and the time taken. @@ -127,7 +139,9 @@ def estimate_flops(self, batch_seqlens: List[int], delta_time: float) -> Tuple[f promised_flops (float): The expected FLOPS of the current device. 
""" tokens_sum = sum(batch_seqlens) - func = self.estimate_func.get(self.config.model_type, self._estimate_unknown_flops) + func = self.estimate_func.get( + self.config.model_type, self._estimate_unknown_flops + ) estimated_flops = func(tokens_sum, batch_seqlens, delta_time) promised_flops = get_device_flops() return estimated_flops, promised_flops diff --git a/Agent0/curriculum_train/verl/utils/fsdp_utils.py b/Agent0/curriculum_train/verl/utils/fsdp_utils.py index 1ca563a..13e3cf7 100644 --- a/Agent0/curriculum_train/verl/utils/fsdp_utils.py +++ b/Agent0/curriculum_train/verl/utils/fsdp_utils.py @@ -27,23 +27,32 @@ from transformers.trainer_pt_utils import get_module_class_from_name -def get_init_fn(model: nn.Module, device: Union[str, torch.device]) -> Callable[[nn.Module], None]: +def get_init_fn( + model: nn.Module, device: Union[str, torch.device] +) -> Callable[[nn.Module], None]: param_occurrence = defaultdict(int) for _, param in model.named_parameters(remove_duplicate=False): param_occurrence[param] += 1 - duplicated_params = {param for param in param_occurrence.keys() if param_occurrence[param] > 1} + duplicated_params = { + param for param in param_occurrence.keys() if param_occurrence[param] > 1 + } materialized_params = {} def init_fn(module: nn.Module): for name, param in module.named_parameters(recurse=False): if param in duplicated_params: module._parameters[name] = materialized_params.setdefault( - param, nn.Parameter(torch.empty_like(param.data, device=device), requires_grad=param.requires_grad) + param, + nn.Parameter( + torch.empty_like(param.data, device=device), + requires_grad=param.requires_grad, + ), ) else: module._parameters[name] = nn.Parameter( - torch.empty_like(param.data, device=device), requires_grad=param.requires_grad + torch.empty_like(param.data, device=device), + requires_grad=param.requires_grad, ) return init_fn @@ -63,7 +72,9 @@ def get_fsdp_wrap_policy(model: PreTrainedModel): else: 
transformer_cls_to_wrap.add(transformer_cls) - return partial(transformer_auto_wrap_policy, transformer_layer_cls=transformer_cls_to_wrap) + return partial( + transformer_auto_wrap_policy, transformer_layer_cls=transformer_cls_to_wrap + ) @torch.no_grad() diff --git a/Agent0/curriculum_train/verl/utils/logger/gen_logger.py b/Agent0/curriculum_train/verl/utils/logger/gen_logger.py index b62cde6..62d618f 100644 --- a/Agent0/curriculum_train/verl/utils/logger/gen_logger.py +++ b/Agent0/curriculum_train/verl/utils/logger/gen_logger.py @@ -38,7 +38,9 @@ def log(self, samples: List[Tuple[str, str, str, float]], step: int) -> None: .. class ConsoleGenerationLogger(GenerationLogger): def log(self, samples: List[Tuple[str, str, str, float]], step: int) -> None: for inp, out, lab, score in samples: - print(f"[prompt] {inp}\n[output] {out}\n[ground_truth] {lab}\n[score] {score}\n") + print( + f"[prompt] {inp}\n[output] {out}\n[ground_truth] {lab}\n[score] {score}\n" + ) @dataclass @@ -46,7 +48,15 @@ class WandbGenerationLogger(GenerationLogger): def log(self, samples: List[Tuple[str, str, str, float]], step: int) -> None: # Create column names for all samples columns = ["step"] + sum( - [[f"input_{i + 1}", f"output_{i + 1}", f"label_{i + 1}", f"score_{i + 1}"] for i in range(len(samples))], + [ + [ + f"input_{i + 1}", + f"output_{i + 1}", + f"label_{i + 1}", + f"score_{i + 1}", + ] + for i in range(len(samples)) + ], [], ) @@ -74,7 +84,12 @@ def log(self, samples: List[Tuple[str, str, str, float]], step: int) -> None: swanlab_text_list = [] for i, sample in enumerate(samples): row_text = "\n\n---\n\n".join( - (f"input: {sample[0]}", f"output: {sample[1]}", f"label: {sample[2]}", f"score: {sample[3]}") + ( + f"input: {sample[0]}", + f"output: {sample[1]}", + f"label: {sample[2]}", + f"score: {sample[3]}", + ) ) swanlab_text_list.append(swanlab.Text(row_text, caption=f"sample {i + 1}")) diff --git a/Agent0/curriculum_train/verl/utils/logger/logger.py 
b/Agent0/curriculum_train/verl/utils/logger/logger.py index a29fb50..cb97513 100644 --- a/Agent0/curriculum_train/verl/utils/logger/logger.py +++ b/Agent0/curriculum_train/verl/utils/logger/logger.py @@ -21,7 +21,12 @@ import torch -from ..py_functional import convert_dict_to_str, flatten_dict, is_package_available, unflatten_dict +from ..py_functional import ( + convert_dict_to_str, + flatten_dict, + is_package_available, + unflatten_dict, +) from .gen_logger import AggregateGenerationsLogger @@ -140,7 +145,11 @@ def finish(self) -> None: class Tracker: - def __init__(self, loggers: Union[str, List[str]] = "console", config: Optional[Dict[str, Any]] = None): + def __init__( + self, + loggers: Union[str, List[str]] = "console", + config: Optional[Dict[str, Any]] = None, + ): if isinstance(loggers, str): loggers = [loggers] @@ -157,7 +166,9 @@ def log(self, data: Dict[str, Any], step: int) -> None: for logger in self.loggers: logger.log(data=data, step=step) - def log_generation(self, samples: List[Tuple[str, str, str, float]], step: int) -> None: + def log_generation( + self, samples: List[Tuple[str, str, str, float]], step: int + ) -> None: self.gen_logger.log(samples, step) def __del__(self): diff --git a/Agent0/curriculum_train/verl/utils/model_utils.py b/Agent0/curriculum_train/verl/utils/model_utils.py index 71d4fe2..2834f10 100644 --- a/Agent0/curriculum_train/verl/utils/model_utils.py +++ b/Agent0/curriculum_train/verl/utils/model_utils.py @@ -32,7 +32,9 @@ def print_gpu_memory_usage(prefix: str = "GPU memory usage") -> None: """Report the current GPU VRAM usage.""" if is_rank0(): free_mem, total_mem = torch.cuda.mem_get_info() - print(f"{prefix}: {(total_mem - free_mem) / (1024**3):.2f} GB / {total_mem / (1024**3):.2f} GB.") + print( + f"{prefix}: {(total_mem - free_mem) / (1024**3):.2f} GB / {total_mem / (1024**3):.2f} GB." 
+ ) def _get_model_size(model: nn.Module, scale: str = "auto") -> Tuple[float, str]: diff --git a/Agent0/curriculum_train/verl/utils/py_functional.py b/Agent0/curriculum_train/verl/utils/py_functional.py index 1a9ed3c..e40d6d7 100644 --- a/Agent0/curriculum_train/verl/utils/py_functional.py +++ b/Agent0/curriculum_train/verl/utils/py_functional.py @@ -57,7 +57,9 @@ def union_two_dict(dict1: Dict[str, Any], dict2: Dict[str, Any]) -> Dict[str, An """Union two dict. Will throw an error if there is an item not the same object with the same key.""" for key in dict2.keys(): if key in dict1: - assert dict1[key] == dict2[key], f"{key} in dict1 and dict2 are not the same object" + assert ( + dict1[key] == dict2[key] + ), f"{key} in dict1 and dict2 are not the same object" dict1[key] = dict2[key] @@ -89,7 +91,9 @@ def unflatten_dict(data: Dict[str, Any], sep: str = "/") -> Dict[str, Any]: return unflattened -def flatten_dict(data: Dict[str, Any], parent_key: str = "", sep: str = "/") -> Dict[str, Any]: +def flatten_dict( + data: Dict[str, Any], parent_key: str = "", sep: str = "/" +) -> Dict[str, Any]: flattened = {} for key, value in data.items(): new_key = parent_key + sep + key if parent_key else key diff --git a/Agent0/curriculum_train/verl/utils/seqlen_balancing.py b/Agent0/curriculum_train/verl/utils/seqlen_balancing.py index 5889784..eaf32b9 100644 --- a/Agent0/curriculum_train/verl/utils/seqlen_balancing.py +++ b/Agent0/curriculum_train/verl/utils/seqlen_balancing.py @@ -99,7 +99,9 @@ def karmarkar_karp(seqlen_list: List[int], k_partitions: int, equal_size: bool): sorted_seqlen_list = sorted([(seqlen, i) for i, seqlen in enumerate(seqlen_list)]) states_pq: List[State] = [] if equal_size: - assert len(seqlen_list) % k_partitions == 0, f"{len(seqlen_list)} % {k_partitions} != 0" + assert ( + len(seqlen_list) % k_partitions == 0 + ), f"{len(seqlen_list)} % {k_partitions} != 0" for offset in range(0, len(sorted_seqlen_list), k_partitions): items = [] for i in 
range(k_partitions): @@ -121,9 +123,9 @@ def karmarkar_karp(seqlen_list: List[int], k_partitions: int, equal_size: bool): partitions = final_state.get_partitions() if equal_size: for i, partition in enumerate(partitions): - assert len(partition) * k_partitions == len(seqlen_list), ( - f"{len(partition)} * {k_partitions} != {len(seqlen_list)}" - ) + assert len(partition) * k_partitions == len( + seqlen_list + ), f"{len(partition)} * {k_partitions} != {len(seqlen_list)}" return partitions @@ -141,13 +143,15 @@ def greedy_partition(seqlen_list: List[int], k_partitions: int, equal_size: bool partition_sums[min_idx] += seqlen if equal_size: for i, partition in enumerate(partitions): - assert len(partition) * k_partitions == len(seqlen_list), ( - f"{len(partition)} * {k_partitions} != {len(seqlen_list)}" - ) + assert len(partition) * k_partitions == len( + seqlen_list + ), f"{len(partition)} * {k_partitions} != {len(seqlen_list)}" return partitions -def get_seqlen_balanced_partitions(seqlen_list: List[int], k_partitions: int, equal_size: bool): +def get_seqlen_balanced_partitions( + seqlen_list: List[int], k_partitions: int, equal_size: bool +): """get order of seq lengths to make partitions balanced, this is used in balacing sum of seqlength across dp ranks and microbatches Parameters: @@ -163,7 +167,9 @@ def get_seqlen_balanced_partitions(seqlen_list: List[int], k_partitions: int, eq partitions (List[List[int]]): return k_partitions list containing the index of items. 
""" - assert len(seqlen_list) >= k_partitions, f"number of items:[{len(seqlen_list)}] < k_partitions:[{k_partitions}]" + assert ( + len(seqlen_list) >= k_partitions + ), f"number of items:[{len(seqlen_list)}] < k_partitions:[{k_partitions}]" def _check_and_sort_partitions(partitions): assert len(partitions) == k_partitions, f"{len(partitions)} != {k_partitions}" @@ -177,7 +183,9 @@ def _check_and_sort_partitions(partitions): assert seen_idx == set(range(len(seqlen_list))) return sorted_partitions - partitions = karmarkar_karp(seqlen_list=seqlen_list, k_partitions=k_partitions, equal_size=equal_size) + partitions = karmarkar_karp( + seqlen_list=seqlen_list, k_partitions=k_partitions, equal_size=equal_size + ) return _check_and_sort_partitions(partitions) @@ -225,9 +233,9 @@ def rearrange_micro_batches(batch: TensorDict, max_token_len, dp_group=None): """ # this is per local micro_bsz max_seq_len = batch["attention_mask"].shape[-1] - assert max_token_len >= max_seq_len, ( - f"max_token_len must be greater than the sequence length. Got {max_token_len=} and {max_seq_len=}" - ) + assert ( + max_token_len >= max_seq_len + ), f"max_token_len must be greater than the sequence length. 
Got {max_token_len=} and {max_seq_len=}" seq_len_effective: torch.Tensor = batch["attention_mask"].sum(dim=1) total_seqlen = seq_len_effective.sum().item() @@ -240,7 +248,9 @@ def rearrange_micro_batches(batch: TensorDict, max_token_len, dp_group=None): seq_len_effective = seq_len_effective.tolist() assert num_micro_batches <= len(seq_len_effective) - micro_bsz_idx = get_seqlen_balanced_partitions(seq_len_effective, num_micro_batches, equal_size=False) + micro_bsz_idx = get_seqlen_balanced_partitions( + seq_len_effective, num_micro_batches, equal_size=False + ) micro_batches = [] diff --git a/Agent0/curriculum_train/verl/utils/tokenizer.py b/Agent0/curriculum_train/verl/utils/tokenizer.py index b339e2a..bb6717a 100644 --- a/Agent0/curriculum_train/verl/utils/tokenizer.py +++ b/Agent0/curriculum_train/verl/utils/tokenizer.py @@ -15,10 +15,17 @@ from typing import Optional -from transformers import AutoProcessor, AutoTokenizer, PreTrainedTokenizer, ProcessorMixin +from transformers import ( + AutoProcessor, + AutoTokenizer, + PreTrainedTokenizer, + ProcessorMixin, +) -def get_tokenizer(model_path: str, override_chat_template: Optional[str] = None, **kwargs) -> PreTrainedTokenizer: +def get_tokenizer( + model_path: str, override_chat_template: Optional[str] = None, **kwargs +) -> PreTrainedTokenizer: """Create a huggingface pretrained tokenizer.""" tokenizer = AutoTokenizer.from_pretrained(model_path, **kwargs) if override_chat_template is not None: @@ -27,7 +34,9 @@ def get_tokenizer(model_path: str, override_chat_template: Optional[str] = None, if tokenizer.bos_token == "" and tokenizer.eos_token == "": # the EOS token in gemma2 & gemma3 is ambiguious, which may worsen RL performance. # https://huggingface.co/google/gemma-2-2b-it/commit/17a01657f5c87135bcdd0ec7abb4b2dece04408a - print("Found gemma model. Set eos_token and eos_token_id to and 107.") + print( + "Found gemma model. Set eos_token and eos_token_id to and 107." 
+ ) tokenizer.eos_token = "" if tokenizer.pad_token_id is None: @@ -37,7 +46,9 @@ def get_tokenizer(model_path: str, override_chat_template: Optional[str] = None, return tokenizer -def get_processor(model_path: str, override_chat_template: Optional[str] = None, **kwargs) -> Optional[ProcessorMixin]: +def get_processor( + model_path: str, override_chat_template: Optional[str] = None, **kwargs +) -> Optional[ProcessorMixin]: """Create a huggingface pretrained processor.""" processor = AutoProcessor.from_pretrained(model_path, **kwargs) if override_chat_template is not None: diff --git a/Agent0/curriculum_train/verl/utils/torch_functional.py b/Agent0/curriculum_train/verl/utils/torch_functional.py index 0bf926e..0b2fe5c 100644 --- a/Agent0/curriculum_train/verl/utils/torch_functional.py +++ b/Agent0/curriculum_train/verl/utils/torch_functional.py @@ -35,7 +35,9 @@ @torch.compiler.disable() -def log_probs_from_logits_flash_attn(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: +def log_probs_from_logits_flash_attn( + logits: torch.Tensor, labels: torch.Tensor +) -> torch.Tensor: output = cross_entropy_loss(logits, labels, inplace_backward=True) if not isinstance(output, tuple): raise ValueError( @@ -69,12 +71,16 @@ def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.T return output.view(*batch_dim) -def masked_mean(values: torch.Tensor, mask: torch.Tensor, dim: int = None, eps: float = 1e-8) -> torch.Tensor: +def masked_mean( + values: torch.Tensor, mask: torch.Tensor, dim: int = None, eps: float = 1e-8 +) -> torch.Tensor: """Compute mean of tensor with a masked values.""" return (values * mask).sum(dim=dim) / (mask.sum(dim=dim) + eps) -def masked_var(values: torch.Tensor, mask: torch.Tensor, unbiased: bool = True) -> torch.Tensor: +def masked_var( + values: torch.Tensor, mask: torch.Tensor, unbiased: bool = True +) -> torch.Tensor: """Compute variance of tensor with masked values.""" mean = masked_mean(values, mask) 
centered_values = values - mean @@ -82,7 +88,9 @@ def masked_var(values: torch.Tensor, mask: torch.Tensor, unbiased: bool = True) if unbiased: mask_sum = mask.sum() if mask_sum <= 1: - print("The sum of the mask is less than one, which can cause a division by zero.") + print( + "The sum of the mask is less than one, which can cause a division by zero." + ) return variance bessel_correction = mask_sum / (mask_sum - 1) @@ -91,14 +99,18 @@ def masked_var(values: torch.Tensor, mask: torch.Tensor, unbiased: bool = True) return variance -def masked_whiten(values: torch.Tensor, mask: torch.Tensor, eps: float = 1e-8) -> torch.Tensor: +def masked_whiten( + values: torch.Tensor, mask: torch.Tensor, eps: float = 1e-8 +) -> torch.Tensor: """Whiten values with masked values.""" mean, var = masked_mean(values, mask), masked_var(values, mask) return (values - mean) * torch.rsqrt(var + eps) def get_response_mask( - response_ids: torch.Tensor, eos_token_id: Union[int, List[int]] = 2, dtype: torch.dtype = torch.long + response_ids: torch.Tensor, + eos_token_id: Union[int, List[int]] = 2, + dtype: torch.dtype = torch.long, ): """Get the mask for the response ids, the mask will be 0 after the first eos token. 
@@ -132,7 +144,10 @@ def pad_2d_list_to_length( else: target_length = max_response_length - padded_response = [tuple(sub_list) + (pad_token_id,) * (target_length - len(sub_list)) for sub_list in response] + padded_response = [ + tuple(sub_list) + (pad_token_id,) * (target_length - len(sub_list)) + for sub_list in response + ] tensor = torch.tensor(padded_response) return tensor @@ -146,8 +161,14 @@ def pad_sequence_to_length( pad_shape = list(tensor.shape) pad_shape[-1] = max_seq_len - tensor.size(-1) - pad_tensor = torch.full(pad_shape, fill_value=pad_token_id, dtype=tensor.dtype, device=tensor.device) - return torch.cat((pad_tensor, tensor), dim=-1) if left_pad else torch.cat((tensor, pad_tensor), dim=-1) + pad_tensor = torch.full( + pad_shape, fill_value=pad_token_id, dtype=tensor.dtype, device=tensor.device + ) + return ( + torch.cat((pad_tensor, tensor), dim=-1) + if left_pad + else torch.cat((tensor, pad_tensor), dim=-1) + ) def postprocess_data( @@ -164,12 +185,17 @@ def postprocess_data( seq_length = len(input_ids) if seq_length < max_length: input_ids = pad_sequence_to_length( - input_ids, max_seq_len=max_length, pad_token_id=pad_token_id, left_pad=left_pad + input_ids, + max_seq_len=max_length, + pad_token_id=pad_token_id, + left_pad=left_pad, ) attention_mask = pad_sequence_to_length( attention_mask, max_seq_len=max_length, pad_token_id=0, left_pad=left_pad ) - position_ids = pad_sequence_to_length(position_ids, max_seq_len=max_length, pad_token_id=0, left_pad=left_pad) + position_ids = pad_sequence_to_length( + position_ids, max_seq_len=max_length, pad_token_id=0, left_pad=left_pad + ) elif seq_length > max_length: if truncation == "left": # actually, left truncation may not be reasonable input_ids = input_ids[..., -max_length:] @@ -180,7 +206,9 @@ def postprocess_data( attention_mask = attention_mask[..., :max_length] position_ids = position_ids[..., :max_length] elif truncation == "error": - raise RuntimeError(f"Input sequence length {seq_length} is 
longer than max length {max_length}.") + raise RuntimeError( + f"Input sequence length {seq_length} is longer than max length {max_length}." + ) else: raise NotImplementedError(f"Unknown truncation method {truncation}.") @@ -282,14 +310,18 @@ def step(self, closure=None): momentum_dtype = PrecisionType.to_dtype(group["momentum_dtype"]) variance_dtype = PrecisionType.to_dtype(group["variance_dtype"]) - compensation_buffer_dtype = PrecisionType.to_dtype(group["compensation_buffer_dtype"]) + compensation_buffer_dtype = PrecisionType.to_dtype( + group["compensation_buffer_dtype"] + ) for p in group["params"]: assert isinstance(p, torch.Tensor) # lint if p.grad is None: continue if p.grad.is_sparse: - raise RuntimeError("AnyPrecisionAdamW does not support sparse gradients.") + raise RuntimeError( + "AnyPrecisionAdamW does not support sparse gradients." + ) state = self.state[p] # State initialization @@ -304,7 +336,9 @@ def step(self, closure=None): # optional Kahan summation - accumulated error tracker if use_kahan_summation: - state["compensation"] = torch.zeros_like(p, dtype=compensation_buffer_dtype) + state["compensation"] = torch.zeros_like( + p, dtype=compensation_buffer_dtype + ) # Main processing # update the steps for each param group update @@ -319,13 +353,19 @@ def step(self, closure=None): p.data.mul_(1 - lr * weight_decay) exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) # update momentum - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) # update uncentered variance + exp_avg_sq.mul_(beta2).addcmul_( + grad, grad, value=1 - beta2 + ) # update uncentered variance bias_correction1 = 1 - beta1**step # adjust using bias1 step_size = lr / bias_correction1 - denom_correction = (1 - beta2**step) ** 0.5 # adjust using bias2 and avoids math import - centered_variance = (exp_avg_sq.sqrt() / denom_correction).add_(eps, alpha=1) + denom_correction = ( + 1 - beta2**step + ) ** 0.5 # adjust using bias2 and avoids math import + centered_variance = 
(exp_avg_sq.sqrt() / denom_correction).add_( + eps, alpha=1 + ) if use_kahan_summation: # lr update to compensation compensation = state["compensation"] diff --git a/Agent0/curriculum_train/verl/utils/ulysses.py b/Agent0/curriculum_train/verl/utils/ulysses.py index 18e07b4..c34a114 100644 --- a/Agent0/curriculum_train/verl/utils/ulysses.py +++ b/Agent0/curriculum_train/verl/utils/ulysses.py @@ -84,7 +84,9 @@ def gather_seq_scatter_heads( return x -def gather_heads_scatter_seq(x: Tensor, head_dim: int, seq_dim: int, group: ProcessGroup = None) -> Tensor: +def gather_heads_scatter_seq( + x: Tensor, head_dim: int, seq_dim: int, group: ProcessGroup = None +) -> Tensor: """ A func to sync attention result with alltoall in sequence parallel gather head dimension and scatter seq dim: @@ -115,7 +117,9 @@ def _unpad_tensor(x: Tensor, dim: int, padding_size: int) -> Tensor: return x[slc] -def slice_input_tensor(x: Tensor, dim: int, padding: bool = True, group: ProcessGroup = None) -> Tensor: +def slice_input_tensor( + x: Tensor, dim: int, padding: bool = True, group: ProcessGroup = None +) -> Tensor: group = get_ulysses_sequence_parallel_group() if group is None else group sp_world_size = dist.get_world_size(group) sp_rank = get_ulysses_sequence_parallel_rank() @@ -140,7 +144,10 @@ def all_to_all_tensor( ): group = get_ulysses_sequence_parallel_group() if group is None else group seq_world_size = dist.get_world_size(group) - input_list = [t.contiguous() for t in torch.tensor_split(local_input, seq_world_size, scatter_dim)] + input_list = [ + t.contiguous() + for t in torch.tensor_split(local_input, seq_world_size, scatter_dim) + ] output_list = [torch.empty_like(input_list[0]) for _ in range(seq_world_size)] comm = dist.all_to_all(output_list, input_list, group=group, async_op=async_op) if async_op: @@ -153,12 +160,18 @@ def wait(): return torch.cat(output_list, dim=gather_dim).contiguous() -def all_gather_tensor(local_tensor: Tensor, group: Optional[dist.ProcessGroup] = 
None, async_op: bool = False): +def all_gather_tensor( + local_tensor: Tensor, + group: Optional[dist.ProcessGroup] = None, + async_op: bool = False, +): group = get_ulysses_sequence_parallel_group() if group is None else group sp_world_size = dist.get_world_size(group=group) output_shape = list(local_tensor.shape) output_shape[0] = output_shape[0] * sp_world_size - output = torch.empty(output_shape, dtype=local_tensor.dtype, device=local_tensor.device) + output = torch.empty( + output_shape, dtype=local_tensor.dtype, device=local_tensor.device + ) dist.all_gather_into_tensor(output, local_tensor, group=group, async_op=async_op) return output @@ -187,7 +200,9 @@ def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]: input_t = grad_output[0] return ( None, - all_to_all_tensor(input_t, ctx.gather_dim, ctx.scatter_dim, ctx.group, False), + all_to_all_tensor( + input_t, ctx.gather_dim, ctx.scatter_dim, ctx.group, False + ), None, None, None, @@ -230,7 +245,9 @@ def backward(ctx: Any, grad_output: Tensor) -> Any: grad_output = grad_output * ctx.sp_world_size return ( None, - grad_output.split(ctx.part_size, dim=ctx.gather_dim)[ctx.sp_rank].contiguous(), + grad_output.split(ctx.part_size, dim=ctx.gather_dim)[ + ctx.sp_rank + ].contiguous(), None, None, None, @@ -252,7 +269,9 @@ def gather_outputs_and_unpad( return x x = Gather.apply(group, x, gather_dim, grad_scaler) if unpad_dim is not None: - assert isinstance(padding_size, int), "padding size is not given or is not an integer" + assert isinstance( + padding_size, int + ), "padding size is not given or is not an integer" if padding_size == 0: return x x = _unpad_tensor(x, unpad_dim, padding_size) @@ -260,7 +279,9 @@ def gather_outputs_and_unpad( def ulysses_pad_and_slice_inputs( - input_ids_rmpad: torch.Tensor, position_ids_rmpad: Optional[torch.Tensor] = None, sp_size: int = 1 + input_ids_rmpad: torch.Tensor, + position_ids_rmpad: Optional[torch.Tensor] = None, + sp_size: int = 1, ): """ Pad 
and slice input_ids to be divisible by sp_size @@ -289,9 +310,13 @@ def ulysses_pad_and_slice_inputs( _, total_seq_len = input_ids_rmpad.shape pad_size = (sp_size - total_seq_len % sp_size) % sp_size if pad_size > 0: - input_ids_rmpad = torch.nn.functional.pad(input_ids_rmpad, (0, pad_size), value=0) + input_ids_rmpad = torch.nn.functional.pad( + input_ids_rmpad, (0, pad_size), value=0 + ) if position_ids_rmpad is not None: - pad_pos_ids = torch.arange(pad_size, device=position_ids_rmpad.device).unsqueeze(0) + pad_pos_ids = torch.arange( + pad_size, device=position_ids_rmpad.device + ).unsqueeze(0) position_ids_rmpad = torch.cat((position_ids_rmpad, pad_pos_ids), dim=-1) # we don't need to slice position ids input_ids_rmpad = slice_input_tensor(input_ids_rmpad, dim=1, padding=False) diff --git a/Agent0/curriculum_train/verl/workers/actor/config.py b/Agent0/curriculum_train/verl/workers/actor/config.py index e792bc4..9f591cf 100644 --- a/Agent0/curriculum_train/verl/workers/actor/config.py +++ b/Agent0/curriculum_train/verl/workers/actor/config.py @@ -33,7 +33,9 @@ def post_init(self): if self.tokenizer_path is None: self.tokenizer_path = self.model_path - if self.model_path is not None and os.path.exists(self.model_path): # ray job uses absolute path + if self.model_path is not None and os.path.exists( + self.model_path + ): # ray job uses absolute path self.model_path = os.path.abspath(self.model_path) if self.tokenizer_path is not None and os.path.exists(self.tokenizer_path): diff --git a/Agent0/curriculum_train/verl/workers/actor/dp_actor.py b/Agent0/curriculum_train/verl/workers/actor/dp_actor.py index 6b771ba..e8a8052 100644 --- a/Agent0/curriculum_train/verl/workers/actor/dp_actor.py +++ b/Agent0/curriculum_train/verl/workers/actor/dp_actor.py @@ -24,7 +24,11 @@ from ray.experimental.tqdm_ray import tqdm from torch import nn from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from transformers.modeling_flash_attention_utils import 
index_first_axis, pad_input, unpad_input +from transformers.modeling_flash_attention_utils import ( + index_first_axis, + pad_input, + unpad_input, +) from ...protocol import DataProto from ...trainer import core_algos @@ -53,11 +57,15 @@ def __init__( self.actor_module = actor_module self.actor_optimizer = actor_optimizer if config.use_torch_compile: - self.log_probs_from_logits = torch.compile(VF.log_probs_from_logits, dynamic=True) + self.log_probs_from_logits = torch.compile( + VF.log_probs_from_logits, dynamic=True + ) else: self.log_probs_from_logits = VF.log_probs_from_logits - def _forward_micro_batch(self, micro_batch: Dict[str, torch.Tensor], temperature: float) -> torch.Tensor: + def _forward_micro_batch( + self, micro_batch: Dict[str, torch.Tensor], temperature: float + ) -> torch.Tensor: """ Returns: log_probs: # (bs, response_len) @@ -69,7 +77,9 @@ def _forward_micro_batch(self, micro_batch: Dict[str, torch.Tensor], temperature responses = micro_batch["responses"] response_length = responses.size(-1) if position_ids.dim() == 3: # qwen2vl mrope - position_ids = position_ids.transpose(0, 1) # (bsz, 3, seqlen) -> (3, bsz, seqlen) + position_ids = position_ids.transpose( + 0, 1 + ) # (bsz, 3, seqlen) -> (3, bsz, seqlen) multi_modal_inputs = {} if "multi_modal_inputs" in micro_batch: @@ -87,28 +97,41 @@ def _forward_micro_batch(self, micro_batch: Dict[str, torch.Tensor], temperature # unpad the position_ids to align the rotary if position_ids.dim() == 3: position_ids_rmpad = ( - index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."), indices) + index_first_axis( + rearrange(position_ids, "c b s ... -> (b s) c ..."), indices + ) .transpose(0, 1) .unsqueeze(1) ) # (3, bsz, seqlen) -> (3, 1, bsz * seqlen) else: position_ids_rmpad = index_first_axis( - rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices + rearrange(position_ids.unsqueeze(-1), "b s ... 
-> (b s) ..."), + indices, ).transpose(0, 1) # for compute the log_prob - input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1) # (1, total_nnz) + input_ids_rmpad_rolled = torch.roll( + input_ids_rmpad, shifts=-1, dims=1 + ) # (1, total_nnz) # pad and slice the inputs if sp > 1 if self.config.ulysses_sequence_parallel_size > 1: - input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs( - input_ids_rmpad, position_ids_rmpad, sp_size=self.config.ulysses_sequence_parallel_size + input_ids_rmpad, position_ids_rmpad, pad_size = ( + ulysses_pad_and_slice_inputs( + input_ids_rmpad, + position_ids_rmpad, + sp_size=self.config.ulysses_sequence_parallel_size, + ) ) input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs( - input_ids_rmpad_rolled, None, self.config.ulysses_sequence_parallel_size + input_ids_rmpad_rolled, + None, + self.config.ulysses_sequence_parallel_size, ) - input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0) # ((total_nnz / sp) + pad) + input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze( + 0 + ) # ((total_nnz / sp) + pad) # only pass input_ids and position_ids to enable flash_attn_varlen output = self.actor_module( @@ -121,18 +144,27 @@ def _forward_micro_batch(self, micro_batch: Dict[str, torch.Tensor], temperature logits_rmpad = output.logits.squeeze(0) # (total_nnz, vocab_size) logits_rmpad.div_(temperature) # ((total_nnz / sp) + pad) - log_probs = self.log_probs_from_logits(logits=logits_rmpad, labels=input_ids_rmpad_rolled) + log_probs = self.log_probs_from_logits( + logits=logits_rmpad, labels=input_ids_rmpad_rolled + ) # gather log_prob if sp > 1 if self.config.ulysses_sequence_parallel_size > 1: # gather and unpad for the ulysses sp - log_probs = gather_outputs_and_unpad(log_probs, gather_dim=0, unpad_dim=0, padding_size=pad_size) + log_probs = gather_outputs_and_unpad( + log_probs, gather_dim=0, unpad_dim=0, padding_size=pad_size + ) # pad back to (bsz, seqlen) full_log_probs = pad_input( - 
hidden_states=log_probs.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen + hidden_states=log_probs.unsqueeze(-1), + indices=indices, + batch=batch_size, + seqlen=seqlen, ) - log_probs = full_log_probs.squeeze(-1)[:, -response_length - 1 : -1] # (bsz, response_length) + log_probs = full_log_probs.squeeze(-1)[ + :, -response_length - 1 : -1 + ] # (bsz, response_length) else: output = self.actor_module( input_ids=input_ids, @@ -143,8 +175,12 @@ def _forward_micro_batch(self, micro_batch: Dict[str, torch.Tensor], temperature ) logits: torch.Tensor = output.logits logits.div_(temperature) - logits = logits[:, -response_length - 1 : -1, :] # (bsz, response_length, vocab_size) - log_probs = self.log_probs_from_logits(logits, responses) # (bsz, response_length) + logits = logits[ + :, -response_length - 1 : -1, : + ] # (bsz, response_length, vocab_size) + log_probs = self.log_probs_from_logits( + logits, responses + ) # (bsz, response_length) return log_probs @@ -152,7 +188,9 @@ def _optimizer_step(self) -> torch.Tensor: if isinstance(self.actor_module, FSDP): grad_norm = self.actor_module.clip_grad_norm_(self.config.max_grad_norm) else: - grad_norm = nn.utils.clip_grad_norm_(self.actor_module.parameters(), max_norm=self.config.max_grad_norm) + grad_norm = nn.utils.clip_grad_norm_( + self.actor_module.parameters(), max_norm=self.config.max_grad_norm + ) if not torch.isfinite(grad_norm): print("Gradient norm is not finite. 
Skip update.") @@ -208,8 +246,17 @@ def compute_log_prob(self, data: DataProto) -> torch.Tensor: def update_policy(self, data: DataProto) -> Dict[str, Any]: self.actor_module.train() - temperature = data.meta_info["temperature"] # temperature must be in the data.meta_info to avoid slient error - select_keys = ["responses", "input_ids", "attention_mask", "position_ids", "old_log_probs", "advantages"] + temperature = data.meta_info[ + "temperature" + ] # temperature must be in the data.meta_info to avoid slient error + select_keys = [ + "responses", + "input_ids", + "attention_mask", + "position_ids", + "old_log_probs", + "advantages", + ] if self.config.use_kl_loss and not self.config.disable_kl: select_keys.append("ref_log_probs") @@ -220,7 +267,9 @@ def update_policy(self, data: DataProto) -> Dict[str, Any]: # Split to make minibatch iterator for updating the actor # See PPO paper for details. https://arxiv.org/abs/1707.06347 - mini_batches = data.select(select_keys, non_tensor_select_keys).split(self.config.global_batch_size_per_device) + mini_batches = data.select(select_keys, non_tensor_select_keys).split( + self.config.global_batch_size_per_device + ) metrics = defaultdict(list) for _ in range(self.config.ppo_epochs): @@ -229,11 +278,16 @@ def update_policy(self, data: DataProto) -> Dict[str, Any]: for mini_batch in mini_batches: gradient_accumulation = ( - self.config.global_batch_size_per_device // self.config.micro_batch_size_per_device_for_update + self.config.global_batch_size_per_device + // self.config.micro_batch_size_per_device_for_update + ) + micro_batches = mini_batch.split( + self.config.micro_batch_size_per_device_for_update ) - micro_batches = mini_batch.split(self.config.micro_batch_size_per_device_for_update) if self.rank == 0: - micro_batches = tqdm(micro_batches, desc="Update policy", position=3) + micro_batches = tqdm( + micro_batches, desc="Update policy", position=3 + ) for micro_batch in micro_batches: model_inputs = 
{**micro_batch.batch, **micro_batch.non_tensor_batch} @@ -245,17 +299,23 @@ def update_policy(self, data: DataProto) -> Dict[str, Any]: advantages = model_inputs["advantages"] # all return: (bsz, response_length) - log_probs = self._forward_micro_batch(model_inputs, temperature=temperature) - entropy_loss = -VF.masked_mean(log_probs, response_mask) # estimator of entropy loss - - pg_loss, pg_clipfrac_higher, pg_clipfrac_lower, ppo_kl = core_algos.compute_policy_loss( - old_log_probs=old_log_probs, - log_probs=log_probs, - advantages=advantages, - response_mask=response_mask, - clip_ratio_low=self.config.clip_ratio_low, - clip_ratio_high=self.config.clip_ratio_high, - clip_ratio_dual=self.config.clip_ratio_dual, + log_probs = self._forward_micro_batch( + model_inputs, temperature=temperature + ) + entropy_loss = -VF.masked_mean( + log_probs, response_mask + ) # estimator of entropy loss + + pg_loss, pg_clipfrac_higher, pg_clipfrac_lower, ppo_kl = ( + core_algos.compute_policy_loss( + old_log_probs=old_log_probs, + log_probs=log_probs, + advantages=advantages, + response_mask=response_mask, + clip_ratio_low=self.config.clip_ratio_low, + clip_ratio_high=self.config.clip_ratio_high, + clip_ratio_dual=self.config.clip_ratio_dual, + ) ) if "ref_log_probs" in model_inputs: ref_log_probs = model_inputs["ref_log_probs"] diff --git a/Agent0/curriculum_train/verl/workers/config.py b/Agent0/curriculum_train/verl/workers/config.py index ba21b0e..422c823 100644 --- a/Agent0/curriculum_train/verl/workers/config.py +++ b/Agent0/curriculum_train/verl/workers/config.py @@ -46,7 +46,11 @@ class WorkerConfig: rollout: RolloutConfig = field(default_factory=RolloutConfig) def post_init(self): - self.ref.micro_batch_size_per_device_for_experience = self.actor.micro_batch_size_per_device_for_experience + self.ref.micro_batch_size_per_device_for_experience = ( + self.actor.micro_batch_size_per_device_for_experience + ) self.ref.padding_free = self.actor.padding_free - 
self.ref.ulysses_sequence_parallel_size = self.actor.ulysses_sequence_parallel_size + self.ref.ulysses_sequence_parallel_size = ( + self.actor.ulysses_sequence_parallel_size + ) self.ref.use_torch_compile = self.actor.use_torch_compile diff --git a/Agent0/curriculum_train/verl/workers/critic/dp_critic.py b/Agent0/curriculum_train/verl/workers/critic/dp_critic.py index 013c8e5..4612813 100644 --- a/Agent0/curriculum_train/verl/workers/critic/dp_critic.py +++ b/Agent0/curriculum_train/verl/workers/critic/dp_critic.py @@ -34,7 +34,12 @@ try: - from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input + from flash_attn.bert_padding import ( + index_first_axis, + pad_input, + rearrange, + unpad_input, + ) except ImportError: pass @@ -43,13 +48,20 @@ class DataParallelPPOCritic(BasePPOCritic): - def __init__(self, config: CriticConfig, critic_module: nn.Module, critic_optimizer: torch.optim.Optimizer): + def __init__( + self, + config: CriticConfig, + critic_module: nn.Module, + critic_optimizer: torch.optim.Optimizer, + ): super().__init__(config) self.rank = int(os.getenv("RANK", "0")) self.critic_module = critic_module self.critic_optimizer = critic_optimizer - def _forward_micro_batch(self, micro_batch: Dict[str, torch.Tensor]) -> torch.Tensor: + def _forward_micro_batch( + self, micro_batch: Dict[str, torch.Tensor] + ) -> torch.Tensor: input_ids = micro_batch["input_ids"] batch_size, seqlen = input_ids.shape attention_mask = micro_batch["attention_mask"] @@ -57,7 +69,9 @@ def _forward_micro_batch(self, micro_batch: Dict[str, torch.Tensor]) -> torch.Te responses = micro_batch["responses"] response_length = responses.size(-1) if position_ids.dim() == 3: # qwen2vl mrope - position_ids = position_ids.transpose(0, 1) # (bsz, 3, seqlen) -> (3, bsz, seqlen) + position_ids = position_ids.transpose( + 0, 1 + ) # (bsz, 3, seqlen) -> (3, bsz, seqlen) multi_modal_inputs = {} if "multi_modal_inputs" in micro_batch: @@ -75,19 +89,26 @@ def 
_forward_micro_batch(self, micro_batch: Dict[str, torch.Tensor]) -> torch.Te # unpad the position_ids to align the rotary if position_ids.dim() == 3: position_ids_rmpad = ( - index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."), indices) + index_first_axis( + rearrange(position_ids, "c b s ... -> (b s) c ..."), indices + ) .transpose(0, 1) .unsqueeze(1) ) # (3, bsz, seqlen) -> (3, 1, bsz * seqlen) else: position_ids_rmpad = index_first_axis( - rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices + rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), + indices, ).transpose(0, 1) # pad and slice the inputs if sp > 1 if self.config.ulysses_sequence_parallel_size > 1: - input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs( - input_ids_rmpad, position_ids_rmpad, sp_size=self.config.ulysses_sequence_parallel_size + input_ids_rmpad, position_ids_rmpad, pad_size = ( + ulysses_pad_and_slice_inputs( + input_ids_rmpad, + position_ids_rmpad, + sp_size=self.config.ulysses_sequence_parallel_size, + ) ) # only pass input_ids and position_ids to enable flash_attn_varlen @@ -103,10 +124,14 @@ def _forward_micro_batch(self, micro_batch: Dict[str, torch.Tensor]) -> torch.Te # gather output if sp > 1 if self.config.ulysses_sequence_parallel_size > 1: - values_rmpad = gather_outputs_and_unpad(values_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size) + values_rmpad = gather_outputs_and_unpad( + values_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size + ) # pad it back - values = pad_input(values_rmpad, indices=indices, batch=batch_size, seqlen=seqlen).squeeze(-1) + values = pad_input( + values_rmpad, indices=indices, batch=batch_size, seqlen=seqlen + ).squeeze(-1) values = values[:, -response_length - 1 : -1] else: output = self.critic_module( @@ -117,7 +142,9 @@ def _forward_micro_batch(self, micro_batch: Dict[str, torch.Tensor]) -> torch.Te use_cache=False, ) values: torch.Tensor = output.logits - values = 
values[:, -response_length - 1 : -1].squeeze(-1) # (bsz, response_length, vocab_size) + values = values[:, -response_length - 1 : -1].squeeze( + -1 + ) # (bsz, response_length, vocab_size) return values @@ -169,7 +196,14 @@ def compute_values(self, data: DataProto) -> torch.Tensor: def update_critic(self, data: DataProto) -> Dict[str, Any]: self.critic_module.train() - select_keys = ["input_ids", "responses", "attention_mask", "position_ids", "values", "returns"] + select_keys = [ + "input_ids", + "responses", + "attention_mask", + "position_ids", + "values", + "returns", + ] if "multi_modal_inputs" in data.non_tensor_batch.keys(): non_tensor_select_keys = ["multi_modal_inputs"] else: @@ -177,7 +211,9 @@ def update_critic(self, data: DataProto) -> Dict[str, Any]: # Split to make minibatch iterator for updating the actor # See PPO paper for details. https://arxiv.org/abs/1707.06347 - mini_batches = data.select(select_keys, non_tensor_select_keys).split(self.config.global_batch_size_per_device) + mini_batches = data.select(select_keys, non_tensor_select_keys).split( + self.config.global_batch_size_per_device + ) metrics = defaultdict(list) for _ in range(self.config.ppo_epochs): @@ -186,11 +222,16 @@ def update_critic(self, data: DataProto) -> Dict[str, Any]: for mini_batch in mini_batches: gradient_accumulation = ( - self.config.global_batch_size_per_device // self.config.micro_batch_size_per_device_for_update + self.config.global_batch_size_per_device + // self.config.micro_batch_size_per_device_for_update + ) + micro_batches = mini_batch.split( + self.config.micro_batch_size_per_device_for_update ) - micro_batches = mini_batch.split(self.config.micro_batch_size_per_device_for_update) if self.rank == 0: - micro_batches = tqdm(micro_batches, desc="Update critic", position=3) + micro_batches = tqdm( + micro_batches, desc="Update critic", position=3 + ) for micro_batch in micro_batches: model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch} @@ -199,7 
+240,9 @@ def update_critic(self, data: DataProto) -> Dict[str, Any]: values = model_inputs["values"] returns = model_inputs["returns"] response_length = responses.size(1) - action_mask = attention_mask[:, -response_length - 1 : -1] # shift left for value computation + action_mask = attention_mask[ + :, -response_length - 1 : -1 + ] # shift left for value computation vpreds = self._forward_micro_batch(model_inputs) vf_loss, vf_clipfrac = core_algos.compute_value_loss( @@ -215,7 +258,9 @@ def update_critic(self, data: DataProto) -> Dict[str, Any]: batch_metrics = { "critic/vf_loss": vf_loss.detach().item(), "critic/vf_clipfrac": vf_clipfrac.detach().item(), - "critic/vpred_mean": VF.masked_mean(vpreds, action_mask).detach().item(), + "critic/vpred_mean": VF.masked_mean(vpreds, action_mask) + .detach() + .item(), } append_to_dict(metrics, batch_metrics) diff --git a/Agent0/curriculum_train/verl/workers/fsdp_workers.py b/Agent0/curriculum_train/verl/workers/fsdp_workers.py index 17c65a9..378f838 100644 --- a/Agent0/curriculum_train/verl/workers/fsdp_workers.py +++ b/Agent0/curriculum_train/verl/workers/fsdp_workers.py @@ -55,8 +55,19 @@ from ..utils.model_utils import print_gpu_memory_usage, print_model_size from ..utils.tokenizer import get_processor, get_tokenizer from ..utils.torch_dtypes import PrecisionType -from ..utils.torch_functional import AnyPrecisionAdamW, get_constant_schedule_with_warmup -from .config import ActorConfig, CriticConfig, FSDPConfig, ModelConfig, OptimConfig, RefConfig, WorkerConfig +from ..utils.torch_functional import ( + AnyPrecisionAdamW, + get_constant_schedule_with_warmup, +) +from .config import ( + ActorConfig, + CriticConfig, + FSDPConfig, + ModelConfig, + OptimConfig, + RefConfig, + WorkerConfig, +) from .rollout import vLLMRollout from .sharding_manager import FSDPVLLMShardingManager from .sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager @@ -66,7 +77,9 @@ class FSDPWorker(Worker): def __init__( self, config: 
WorkerConfig, - role: Literal["actor", "critic", "rollout", "ref", "actor_rollout", "actor_rollout_ref"], + role: Literal[ + "actor", "critic", "rollout", "ref", "actor_rollout", "actor_rollout_ref" + ], ): super().__init__() self.config = config @@ -81,7 +94,11 @@ def __init__( self._is_actor = self.role in ["actor", "actor_rollout", "actor_rollout_ref"] self._is_critic = self.role == "critic" - self._is_rollout = self.role in ["rollout", "actor_rollout", "actor_rollout_ref"] + self._is_rollout = self.role in [ + "rollout", + "actor_rollout", + "actor_rollout_ref", + ] self._is_ref = self.role in ["ref", "actor_rollout_ref"] self._cache = {} @@ -95,20 +112,28 @@ def __init__( self._use_param_offload = self.config.critic.offload.offload_params self._use_optimizer_offload = self.config.critic.offload.offload_optimizer self._init_config(self.config.critic, "critic") - elif self._is_ref: # NOTE: it seems that manual offload is slower than FSDP offload + elif ( + self._is_ref + ): # NOTE: it seems that manual offload is slower than FSDP offload self._use_param_offload = self.config.ref.offload.offload_params self._init_config(self.config.ref, "ref") def _init_config( - self, config: Union[ActorConfig, CriticConfig, RefConfig], role: Literal["actor", "critic", "ref"] + self, + config: Union[ActorConfig, CriticConfig, RefConfig], + role: Literal["actor", "critic", "ref"], ): world_size = dist.get_world_size() fsdp_size = config.fsdp.fsdp_size if fsdp_size <= 0 or fsdp_size >= world_size: - self.device_mesh = init_device_mesh("cuda", mesh_shape=(world_size,), mesh_dim_names=("fsdp",)) + self.device_mesh = init_device_mesh( + "cuda", mesh_shape=(world_size,), mesh_dim_names=("fsdp",) + ) else: # hsdp self.device_mesh = init_device_mesh( - "cuda", mesh_shape=(world_size // fsdp_size, fsdp_size), mesh_dim_names=("ddp", "fsdp") + "cuda", + mesh_shape=(world_size // fsdp_size, fsdp_size), + mesh_dim_names=("ddp", "fsdp"), ) if config.ulysses_sequence_parallel_size > 1: @@ 
-123,29 +148,46 @@ def _init_config( else: self.ulysses_device_mesh = None - self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh) + self.ulysses_sharding_manager = FSDPUlyssesShardingManager( + self.ulysses_device_mesh + ) if not hasattr(config, "global_batch_size"): # ref model return if self.config.rollout.n > 1: config.global_batch_size *= self.config.rollout.n - self.print_rank0(f"{role} will use global batch size {config.global_batch_size}.") + self.print_rank0( + f"{role} will use global batch size {config.global_batch_size}." + ) config.global_batch_size_per_device = ( - config.global_batch_size // self.device_mesh.size() * config.ulysses_sequence_parallel_size + config.global_batch_size + // self.device_mesh.size() + * config.ulysses_sequence_parallel_size ) if config.global_batch_size_per_device == 0: - raise ValueError(f"{role} global batch size * ulysses size must be larger than num gpus.") + raise ValueError( + f"{role} global batch size * ulysses size must be larger than num gpus." + ) - if config.global_batch_size_per_device % config.micro_batch_size_per_device_for_update != 0: - raise ValueError(f"{role} global batch size per device must be divisible by the micro batch size.") + if ( + config.global_batch_size_per_device + % config.micro_batch_size_per_device_for_update + != 0 + ): + raise ValueError( + f"{role} global batch size per device must be divisible by the micro batch size." + ) if ( config.fsdp.enable_cpu_offload - and config.global_batch_size_per_device != config.micro_batch_size_per_device_for_update + and config.global_batch_size_per_device + != config.micro_batch_size_per_device_for_update ): - raise ValueError(f"{role} cannot use FSDP's CPU offload when gradient accumulation is enabled.") + raise ValueError( + f"{role} cannot use FSDP's CPU offload when gradient accumulation is enabled." 
+ ) def _build_model_optimizer( self, @@ -174,9 +216,13 @@ def _build_model_optimizer( ) try: - self.generation_config = GenerationConfig.from_pretrained(model_config.model_path) + self.generation_config = GenerationConfig.from_pretrained( + model_config.model_path + ) except Exception: - self.generation_config = GenerationConfig.from_model_config(self.model_config) + self.generation_config = GenerationConfig.from_model_config( + self.model_config + ) self.print_rank0(f"Model config: {self.model_config}") @@ -185,7 +231,9 @@ def _build_model_optimizer( self.print_rank0("Ulysses patch applied!") if fsdp_config.torch_dtype is None: - torch_dtype = torch.float32 if self._is_actor or self._is_critic else torch.bfloat16 + torch_dtype = ( + torch.float32 if self._is_actor or self._is_critic else torch.bfloat16 + ) else: torch_dtype = PrecisionType.to_dtype(fsdp_config.torch_dtype) @@ -196,11 +244,13 @@ def _build_model_optimizer( else: auto_class = AutoModelForCausalLM - if (not fsdp_config.enable_rank0_init) or self.device_mesh.get_local_rank("fsdp") == 0: + if (not fsdp_config.enable_rank0_init) or self.device_mesh.get_local_rank( + "fsdp" + ) == 0: model = auto_class.from_pretrained( model_config.model_path, config=self.model_config, - torch_dtype='bfloat16', + torch_dtype="bfloat16", attn_implementation="flash_attention_2", device_map="cpu" if fsdp_config.enable_rank0_init else "cuda", low_cpu_mem_usage=True, @@ -210,7 +260,7 @@ def _build_model_optimizer( with no_init_weights(), init_empty_weights(): model = auto_class.from_config( self.model_config, - torch_dtype='bfloat16', + torch_dtype="bfloat16", attn_implementation="flash_attention_2", trust_remote_code=model_config.trust_remote_code, ) @@ -219,7 +269,9 @@ def _build_model_optimizer( model.tie_weights() # avoid hanging model = model.to(torch_dtype) if model_config.enable_gradient_checkpointing: - model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) + 
model.gradient_checkpointing_enable( + gradient_checkpointing_kwargs={"use_reentrant": False} + ) if not (self._is_actor or self._is_critic): model.requires_grad_(False) @@ -261,7 +313,9 @@ def _build_model_optimizer( if fsdp_config.enable_rank0_init: sync_module_states = True - param_init_fn = get_init_fn(model, device="cuda") if self.rank != 0 else None + param_init_fn = ( + get_init_fn(model, device="cuda") if self.rank != 0 else None + ) else: sync_module_states = False param_init_fn = None @@ -298,9 +352,13 @@ def _build_model_optimizer( weight_decay=optim_config.weight_decay, ) else: - raise NotImplementedError(f"Optimizer {optim_config.strategy} not supported.") + raise NotImplementedError( + f"Optimizer {optim_config.strategy} not supported." + ) - num_warmup_steps = int(optim_config.lr_warmup_ratio * optim_config.training_steps) + num_warmup_steps = int( + optim_config.lr_warmup_ratio * optim_config.training_steps + ) self.lr_scheduler = get_constant_schedule_with_warmup( optimizer=self.optimizer, num_warmup_steps=num_warmup_steps ) @@ -311,10 +369,12 @@ def _build_model_optimizer( def _build_rollout(self) -> None: tp_size = self.config.rollout.tensor_parallel_size dp_size = self.world_size // tp_size - assert self.world_size % tp_size == 0, ( - f"rollout world size: {self.world_size} is not divisible by tp size: {tp_size}" + assert ( + self.world_size % tp_size == 0 + ), f"rollout world size: {self.world_size} is not divisible by tp size: {tp_size}" + rollout_device_mesh = init_device_mesh( + "cuda", mesh_shape=(dp_size, tp_size), mesh_dim_names=("dp", "tp") ) - rollout_device_mesh = init_device_mesh("cuda", mesh_shape=(dp_size, tp_size), mesh_dim_names=("dp", "tp")) self.rollout = vLLMRollout( model_path=self.config.actor.model.model_path, config=self.config.rollout, @@ -400,7 +460,9 @@ def init_model(self): model=self.fsdp_module, optimizer=self.optimizer, lr_scheduler=self.lr_scheduler, - processing_class=self.processor if self.processor is not None 
else self.tokenizer, + processing_class=( + self.processor if self.processor is not None else self.tokenizer + ), ) @register(dispatch_mode=Dispatch.ONE_TO_ALL) @@ -436,7 +498,7 @@ def preprocess_multi_modal_data(self, data: DataProto): processed_images = [] for multi_modal_data in multi_modal_data_copy: processed_per_query_images = [] - for image in multi_modal_data['image']: + for image in multi_modal_data["image"]: processed_per_query_images.append( process_image(image, min_pixels=min_pixels, max_pixels=max_pixels) ) @@ -454,17 +516,24 @@ def preprocess_multi_modal_data(self, data: DataProto): # for j, image in enumerate(per_query_images): # images[i][j] = process_image(image, min_pixels=min_pixels, max_pixels=max_pixels) - multi_modal_inputs = np.array([ - dict(self.processor.image_processor(images=per_query_images, videos=None)) - for per_query_images in processed_images - ], dtype=object) + multi_modal_inputs = np.array( + [ + dict( + self.processor.image_processor(images=per_query_images, videos=None) + ) + for per_query_images in processed_images + ], + dtype=object, + ) data.non_tensor_batch["multi_modal_inputs"] = multi_modal_inputs @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) def update_actor(self, data: DataProto): assert self._is_actor if "multi_modal_inputs" in self._cache: - data.non_tensor_batch['multi_modal_inputs'] = deepcopy(self._cache['multi_modal_inputs']) + data.non_tensor_batch["multi_modal_inputs"] = deepcopy( + self._cache["multi_modal_inputs"] + ) elif "multi_modal_data" in data.non_tensor_batch: self.preprocess_multi_modal_data(data) @@ -483,17 +552,25 @@ def update_actor(self, data: DataProto): delta_time = timer.last global_num_tokens = data.meta_info["global_token_num"] - estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time) + estimated_flops, promised_flops = self.flops_counter.estimate_flops( + global_num_tokens, delta_time + ) metrics["perf/mfu_actor"] = ( - estimated_flops * 
self.config.actor.ppo_epochs / (promised_flops * self.world_size) + estimated_flops + * self.config.actor.ppo_epochs + / (promised_flops * self.world_size) ) metrics["perf/max_memory_allocated_gb"] = ( - torch.cuda.max_memory_allocated() - self.rollout_sharding_manager.freed_bytes + torch.cuda.max_memory_allocated() + - self.rollout_sharding_manager.freed_bytes ) / (1024**3) metrics["perf/max_memory_reserved_gb"] = ( - torch.cuda.max_memory_reserved() - self.rollout_sharding_manager.freed_bytes + torch.cuda.max_memory_reserved() + - self.rollout_sharding_manager.freed_bytes ) / (1024**3) - metrics["perf/cpu_memory_used_gb"] = psutil.virtual_memory().used / (1024**3) + metrics["perf/cpu_memory_used_gb"] = psutil.virtual_memory().used / ( + 1024**3 + ) self.lr_scheduler.step() lr = self.lr_scheduler.get_last_lr()[0] @@ -502,7 +579,8 @@ def update_actor(self, data: DataProto): # Metrics should be in non_tensor_batch instead of meta_info, as DataProto not concat meta_info. output = DataProto( non_tensor_batch={ - key: np.array([value] if np.isscalar(value) else value) for key, value in metrics.items() + key: np.array([value] if np.isscalar(value) else value) + for key, value in metrics.items() } ) @@ -523,12 +601,16 @@ def generate_sequences(self, prompts: DataProto): load_fsdp_model(self.fsdp_module) meta_info = { - "eos_token_id": self.generation_config.eos_token_id - if self.generation_config is not None - else self.tokenizer.eos_token_id, - "pad_token_id": self.generation_config.pad_token_id - if self.generation_config is not None - else self.tokenizer.pad_token_id, + "eos_token_id": ( + self.generation_config.eos_token_id + if self.generation_config is not None + else self.tokenizer.eos_token_id + ), + "pad_token_id": ( + self.generation_config.pad_token_id + if self.generation_config is not None + else self.tokenizer.pad_token_id + ), } prompts.meta_info.update(meta_info) with self.rollout_sharding_manager: @@ -544,13 +626,19 @@ def generate_sequences(self, 
prompts: DataProto): # load image data cached_multi_modal_data = None if "multi_modal_data" in prompts.non_tensor_batch: - cached_multi_modal_data = deepcopy(prompts.non_tensor_batch["multi_modal_data"]) - min_pixels = prompts.meta_info['min_pixels'] - max_pixels = prompts.meta_info['max_pixels'] + cached_multi_modal_data = deepcopy( + prompts.non_tensor_batch["multi_modal_data"] + ) + min_pixels = prompts.meta_info["min_pixels"] + max_pixels = prompts.meta_info["max_pixels"] processed_images = [] - for i, multi_modal_data in enumerate(prompts.non_tensor_batch["multi_modal_data"]): + for i, multi_modal_data in enumerate( + prompts.non_tensor_batch["multi_modal_data"] + ): for j, image in enumerate(multi_modal_data["image"]): - multi_modal_data['image'][j] = process_image(image, min_pixels=min_pixels, max_pixels=max_pixels) + multi_modal_data["image"][j] = process_image( + image, min_pixels=min_pixels, max_pixels=max_pixels + ) processed_images.append(multi_modal_data) prompts.non_tensor_batch["multi_modal_data"] = processed_images @@ -562,7 +650,9 @@ def generate_sequences(self, prompts: DataProto): output.non_tensor_batch["multi_modal_data"] = cached_multi_modal_data if sampling_n > 1: output.non_tensor_batch["multi_modal_data"] = np.repeat( - output.non_tensor_batch["multi_modal_data"], repeats=sampling_n, axis=0, + output.non_tensor_batch["multi_modal_data"], + repeats=sampling_n, + axis=0, ) output = self.rollout_sharding_manager.postprocess_data(output) @@ -577,7 +667,9 @@ def compute_log_probs(self, data: DataProto): if "multi_modal_data" in data.non_tensor_batch: self.preprocess_multi_modal_data(data) # create cache for multi_modal_inputs - self._cache['multi_modal_inputs'] = deepcopy(data.non_tensor_batch['multi_modal_inputs']) + self._cache["multi_modal_inputs"] = deepcopy( + data.non_tensor_batch["multi_modal_inputs"] + ) data = data.to(torch.cuda.current_device()) if self._use_param_offload: @@ -590,7 +682,8 @@ def compute_log_probs(self, data: 
DataProto): data = self.ulysses_sharding_manager.preprocess_data(data) output = self.actor.compute_log_prob(data=data) output = DataProto.from_dict( - tensors={"old_log_probs": output}, meta_info={"temperature": self.config.rollout.temperature} + tensors={"old_log_probs": output}, + meta_info={"temperature": self.config.rollout.temperature}, ) output = self.ulysses_sharding_manager.postprocess_data(output) @@ -611,7 +704,9 @@ def compute_ref_log_probs(self, data: DataProto): # not in the ref_policy's or critic's caches. assert self._is_ref if "multi_modal_inputs" in self._cache: - data.non_tensor_batch['multi_modal_inputs'] = deepcopy(self._cache['multi_modal_inputs']) + data.non_tensor_batch["multi_modal_inputs"] = deepcopy( + self._cache["multi_modal_inputs"] + ) elif "multi_modal_data" in data.non_tensor_batch: self.preprocess_multi_modal_data(data) @@ -643,7 +738,9 @@ def compute_values(self, data: DataProto): # The `self._cache` is empty here since cached `multi_modal_inputs` is only saved in the actor's _cache, # not in the ref_policy's or critic's caches. if "multi_modal_inputs" in self._cache: - data.non_tensor_batch['multi_modal_inputs'] = deepcopy(self._cache['multi_modal_inputs']) + data.non_tensor_batch["multi_modal_inputs"] = deepcopy( + self._cache["multi_modal_inputs"] + ) elif "multi_modal_data" in data.non_tensor_batch: self.preprocess_multi_modal_data(data) @@ -668,7 +765,9 @@ def update_critic(self, data: DataProto): # The `self._cache` is empty here since cached `multi_modal_inputs` is only saved in the actor's _cache, # not in the ref_policy's or critic's caches. 
if "multi_modal_inputs" in self._cache: - data.non_tensor_batch['multi_modal_inputs'] = deepcopy(self._cache['multi_modal_inputs']) + data.non_tensor_batch["multi_modal_inputs"] = deepcopy( + self._cache["multi_modal_inputs"] + ) elif "multi_modal_data" not in data.non_tensor_batch: self.preprocess_multi_modal_data(data) @@ -686,9 +785,13 @@ def update_critic(self, data: DataProto): delta_time = timer.last global_num_tokens = data.meta_info["global_token_num"] - estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time) + estimated_flops, promised_flops = self.flops_counter.estimate_flops( + global_num_tokens, delta_time + ) metrics["perf/mfu_critic"] = ( - estimated_flops * self.config.actor.ppo_epochs / (promised_flops * self.world_size) + estimated_flops + * self.config.actor.ppo_epochs + / (promised_flops * self.world_size) ) self.lr_scheduler.step() @@ -698,7 +801,8 @@ def update_critic(self, data: DataProto): # Metrics should be in non_tensor_batch instead of meta_info, as DataProto not concat meta_info. output = DataProto( non_tensor_batch={ - metric: np.array([value] if np.isscalar(value) else value) for metric, value in metrics.items() + metric: np.array([value] if np.isscalar(value) else value) + for metric, value in metrics.items() } ) diff --git a/Agent0/curriculum_train/verl/workers/reward/__init__.py b/Agent0/curriculum_train/verl/workers/reward/__init__.py index 9d476f6..d9227ec 100644 --- a/Agent0/curriculum_train/verl/workers/reward/__init__.py +++ b/Agent0/curriculum_train/verl/workers/reward/__init__.py @@ -13,7 +13,16 @@ # limitations under the License. 
from .config import RewardConfig -from .function import BatchFunctionRewardManager, FunctionRewardManager, SequentialFunctionRewardManager +from .function import ( + BatchFunctionRewardManager, + FunctionRewardManager, + SequentialFunctionRewardManager, +) -__all__ = ["BatchFunctionRewardManager", "FunctionRewardManager", "RewardConfig", "SequentialFunctionRewardManager"] +__all__ = [ + "BatchFunctionRewardManager", + "FunctionRewardManager", + "RewardConfig", + "SequentialFunctionRewardManager", +] diff --git a/Agent0/curriculum_train/verl/workers/reward/config.py b/Agent0/curriculum_train/verl/workers/reward/config.py index 7e11bdb..7620660 100644 --- a/Agent0/curriculum_train/verl/workers/reward/config.py +++ b/Agent0/curriculum_train/verl/workers/reward/config.py @@ -31,11 +31,15 @@ class RewardConfig: reward_function_name: Optional[str] = field(default=None, init=False) def post_init(self): - if self.reward_function is not None: # support custom reward function, e.g., ./math.py:main + if ( + self.reward_function is not None + ): # support custom reward function, e.g., ./math.py:main if ":" not in self.reward_function: self.reward_function_name = "main" else: - self.reward_function, self.reward_function_name = self.reward_function.rsplit(":", maxsplit=1) + self.reward_function, self.reward_function_name = ( + self.reward_function.rsplit(":", maxsplit=1) + ) if os.path.exists(self.reward_function): # ray job uses absolute path self.reward_function = os.path.abspath(self.reward_function) diff --git a/Agent0/curriculum_train/verl/workers/reward/function.py b/Agent0/curriculum_train/verl/workers/reward/function.py index a7af022..f47c6b9 100644 --- a/Agent0/curriculum_train/verl/workers/reward/function.py +++ b/Agent0/curriculum_train/verl/workers/reward/function.py @@ -46,9 +46,13 @@ def __init__(self, config: RewardConfig, tokenizer: PreTrainedTokenizer): raise ValueError("Reward function is not provided.") if not os.path.exists(config.reward_function): - raise 
FileNotFoundError(f"Reward function file {config.reward_function} not found.") + raise FileNotFoundError( + f"Reward function file {config.reward_function} not found." + ) - spec = importlib.util.spec_from_file_location("custom_reward_fn", config.reward_function) + spec = importlib.util.spec_from_file_location( + "custom_reward_fn", config.reward_function + ) module = importlib.util.module_from_spec(spec) try: sys.modules["custom_reward_fn"] = module @@ -57,16 +61,22 @@ def __init__(self, config: RewardConfig, tokenizer: PreTrainedTokenizer): raise RuntimeError(f"Failed to load reward function: {e}") if not hasattr(module, config.reward_function_name): - raise AttributeError(f"Module {module} does not have function {config.reward_function_name}.") + raise AttributeError( + f"Module {module} does not have function {config.reward_function_name}." + ) reward_fn = getattr(module, config.reward_function_name) - print(f"Using reward function `{config.reward_function_name}` from `{config.reward_function}`.") + print( + f"Using reward function `{config.reward_function_name}` from `{config.reward_function}`." + ) self.reward_fn = partial(reward_fn, **config.reward_function_kwargs) self.config = config self.tokenizer = tokenizer @abstractmethod - def compute_reward(self, data: DataProto) -> Tuple[torch.Tensor, Dict[str, List[float]]]: + def compute_reward( + self, data: DataProto + ) -> Tuple[torch.Tensor, Dict[str, List[float]]]: """Compute reward for a batch of data.""" ... 
@@ -74,7 +84,9 @@ def compute_reward(self, data: DataProto) -> Tuple[torch.Tensor, Dict[str, List[ class SequentialFunctionRewardManager(FunctionRewardManager): reward_fn: SequentialRewardFunction - def compute_reward(self, data: DataProto) -> Tuple[torch.Tensor, Dict[str, List[float]]]: + def compute_reward( + self, data: DataProto + ) -> Tuple[torch.Tensor, Dict[str, List[float]]]: reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32) reward_metrics = defaultdict(list) response_ids = data.batch["responses"] @@ -97,14 +109,19 @@ def compute_reward(self, data: DataProto) -> Tuple[torch.Tensor, Dict[str, List[ class BatchFunctionRewardManager(FunctionRewardManager): reward_fn: BatchRewardFunction - def compute_reward(self, data: DataProto) -> Tuple[torch.Tensor, Dict[str, List[float]]]: + def compute_reward( + self, data: DataProto + ) -> Tuple[torch.Tensor, Dict[str, List[float]]]: response_str, ground_truth = [], [] response_ids = data.batch["responses"] response_length = data.batch["response_mask"].sum(dim=-1) for i in range(len(data)): valid_response_ids = response_ids[i][: response_length[i]] response_str.append( - self.tokenizer.decode(valid_response_ids, skip_special_tokens=self.config.skip_special_tokens) + self.tokenizer.decode( + valid_response_ids, + skip_special_tokens=self.config.skip_special_tokens, + ) ) ground_truth.append(data.non_tensor_batch["ground_truth"][i]) diff --git a/Agent0/curriculum_train/verl/workers/rollout/vllm_rollout_spmd.py b/Agent0/curriculum_train/verl/workers/rollout/vllm_rollout_spmd.py index 13cb4d7..4862a88 100644 --- a/Agent0/curriculum_train/verl/workers/rollout/vllm_rollout_spmd.py +++ b/Agent0/curriculum_train/verl/workers/rollout/vllm_rollout_spmd.py @@ -31,24 +31,34 @@ from .config import RolloutConfig import traceback -def _repeat_interleave(value: Union[torch.Tensor, np.ndarray], repeats: int) -> Union[torch.Tensor, List[Any]]: + + +def _repeat_interleave( + value: Union[torch.Tensor, 
np.ndarray], repeats: int +) -> Union[torch.Tensor, List[Any]]: if isinstance(value, torch.Tensor): return value.repeat_interleave(repeats, dim=0) else: return np.repeat(value, repeats, axis=0) -def _get_logit_bias(model_path: str, trust_remote_code: bool) -> Optional[Dict[int, float]]: +def _get_logit_bias( + model_path: str, trust_remote_code: bool +) -> Optional[Dict[int, float]]: processor = get_processor(model_path, trust_remote_code=trust_remote_code) if processor is not None and hasattr(processor, "image_token"): - image_token_id = processor.tokenizer.convert_tokens_to_ids(processor.image_token) + image_token_id = processor.tokenizer.convert_tokens_to_ids( + processor.image_token + ) return {image_token_id: -100} else: return None class vLLMRollout(BaseRollout): - def __init__(self, model_path: str, config: RolloutConfig, tokenizer: PreTrainedTokenizer): + def __init__( + self, model_path: str, config: RolloutConfig, tokenizer: PreTrainedTokenizer + ): """A vLLM rollout. It requires the module is supported by the vllm. Args: @@ -63,8 +73,13 @@ def __init__(self, model_path: str, config: RolloutConfig, tokenizer: PreTrained if config.tensor_parallel_size > torch.distributed.get_world_size(): raise ValueError("Tensor parallelism size should be less than world size.") - if config.max_num_batched_tokens < config.prompt_length + config.response_length: - raise ValueError("max_num_batched_tokens should be greater than prompt_length + response_length.") + if ( + config.max_num_batched_tokens + < config.prompt_length + config.response_length + ): + raise ValueError( + "max_num_batched_tokens should be greater than prompt_length + response_length." 
+ ) engine_kwargs = {} if config.limit_images: @@ -77,7 +92,8 @@ def __init__(self, model_path: str, config: RolloutConfig, tokenizer: PreTrained load_format="dummy", dtype=PrecisionType.to_str(PrecisionType.to_dtype(config.dtype)), seed=config.seed, - max_model_len=config.max_model_len or config.prompt_length + config.response_length, + max_model_len=config.max_model_len + or config.prompt_length + config.response_length, distributed_executor_backend="external_launcher", tensor_parallel_size=config.tensor_parallel_size, gpu_memory_utilization=config.gpu_memory_utilization, @@ -97,11 +113,13 @@ def __init__(self, model_path: str, config: RolloutConfig, tokenizer: PreTrained sampling_kwargs = { "max_tokens": config.response_length, "detokenize": False, - "logit_bias": _get_logit_bias(model_path, trust_remote_code=config.trust_remote_code), + "logit_bias": _get_logit_bias( + model_path, trust_remote_code=config.trust_remote_code + ), } default_sampling_params = SamplingParams() for key in config.to_dict().keys(): - if key == 'seed': + if key == "seed": continue if hasattr(default_sampling_params, key): sampling_kwargs[key] = getattr(config, key) @@ -144,20 +162,33 @@ def generate_sequences(self, prompts: DataProto) -> DataProto: if "multi_modal_data" in non_tensor_batch: vllm_inputs = [] for raw_prompt_ids, multi_modal_data in zip( - non_tensor_batch.pop("raw_prompt_ids"), non_tensor_batch.pop("multi_modal_data") + non_tensor_batch.pop("raw_prompt_ids"), + non_tensor_batch.pop("multi_modal_data"), ): - vllm_inputs.append({"prompt_token_ids": list(raw_prompt_ids), "multi_modal_data": multi_modal_data}) + vllm_inputs.append( + { + "prompt_token_ids": list(raw_prompt_ids), + "multi_modal_data": multi_modal_data, + } + ) else: vllm_inputs = [ - {"prompt_token_ids": list(raw_prompt_ids)} for raw_prompt_ids in non_tensor_batch.pop("raw_prompt_ids") + {"prompt_token_ids": list(raw_prompt_ids)} + for raw_prompt_ids in non_tensor_batch.pop("raw_prompt_ids") ] # users can 
customize different sampling_params at different run with self.update_sampling_params(**prompts.meta_info): completions: List[RequestOutput] = self.inference_engine.generate( - prompts=vllm_inputs, sampling_params=self.sampling_params, use_tqdm=False + prompts=vllm_inputs, + sampling_params=self.sampling_params, + use_tqdm=False, ) - response_ids = [output.token_ids for completion in completions for output in completion.outputs] + response_ids = [ + output.token_ids + for completion in completions + for output in completion.outputs + ] response_ids = VF.pad_2d_list_to_length( response_ids, self.pad_token_id, max_length=self.config.response_length ).to(input_ids.device) @@ -165,15 +196,21 @@ def generate_sequences(self, prompts: DataProto) -> DataProto: if self.sampling_params.n > 1: batch_size = batch_size * self.sampling_params.n input_ids = _repeat_interleave(input_ids, self.sampling_params.n) - attention_mask = _repeat_interleave(attention_mask, self.sampling_params.n) + attention_mask = _repeat_interleave( + attention_mask, self.sampling_params.n + ) position_ids = _repeat_interleave(position_ids, self.sampling_params.n) sequence_ids = torch.cat([input_ids, response_ids], dim=-1) response_length = response_ids.size(1) - delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device) + delta_position_id = torch.arange( + 1, response_length + 1, device=position_ids.device + ) delta_position_id = delta_position_id.view(1, -1).expand(batch_size, -1) if position_ids.dim() == 3: # qwen2vl mrope - delta_position_id = delta_position_id.view(batch_size, 1, -1).expand(batch_size, 3, -1) + delta_position_id = delta_position_id.view(batch_size, 1, -1).expand( + batch_size, 3, -1 + ) # prompt: left pad + response: right pad # attention_mask: [0,0,0,0,1,1,1,1 | 1,1,1,0,0,0,0,0] @@ -181,7 +218,9 @@ def generate_sequences(self, prompts: DataProto) -> DataProto: response_position_ids = position_ids[..., -1:] + delta_position_id position_ids = 
torch.cat([position_ids, response_position_ids], dim=-1) response_mask = VF.get_response_mask( - response_ids=response_ids, eos_token_id=eos_token_id, dtype=attention_mask.dtype + response_ids=response_ids, + eos_token_id=eos_token_id, + dtype=attention_mask.dtype, ) attention_mask = torch.cat((attention_mask, response_mask), dim=-1) # all the tp ranks should contain the same data here. data in all ranks are valid diff --git a/Agent0/curriculum_train/verl/workers/sharding_manager/__init__.py b/Agent0/curriculum_train/verl/workers/sharding_manager/__init__.py index 88eaee4..cf06253 100644 --- a/Agent0/curriculum_train/verl/workers/sharding_manager/__init__.py +++ b/Agent0/curriculum_train/verl/workers/sharding_manager/__init__.py @@ -18,4 +18,8 @@ from .fsdp_vllm import FSDPVLLMShardingManager -__all__ = ["BaseShardingManager", "FSDPUlyssesShardingManager", "FSDPVLLMShardingManager"] +__all__ = [ + "BaseShardingManager", + "FSDPUlyssesShardingManager", + "FSDPVLLMShardingManager", +] diff --git a/Agent0/curriculum_train/verl/workers/sharding_manager/fsdp_ulysses.py b/Agent0/curriculum_train/verl/workers/sharding_manager/fsdp_ulysses.py index c2ce5b9..5bb3dcf 100644 --- a/Agent0/curriculum_train/verl/workers/sharding_manager/fsdp_ulysses.py +++ b/Agent0/curriculum_train/verl/workers/sharding_manager/fsdp_ulysses.py @@ -18,7 +18,10 @@ from torch.distributed.device_mesh import DeviceMesh from ...protocol import DataProto, all_gather_data_proto -from ...utils.ulysses import get_ulysses_sequence_parallel_group, set_ulysses_sequence_parallel_group +from ...utils.ulysses import ( + get_ulysses_sequence_parallel_group, + set_ulysses_sequence_parallel_group, +) from .base import BaseShardingManager diff --git a/Agent0/curriculum_train/verl/workers/sharding_manager/fsdp_vllm.py b/Agent0/curriculum_train/verl/workers/sharding_manager/fsdp_vllm.py index 11f1090..a2ad4d0 100644 --- a/Agent0/curriculum_train/verl/workers/sharding_manager/fsdp_vllm.py +++ 
b/Agent0/curriculum_train/verl/workers/sharding_manager/fsdp_vllm.py @@ -21,7 +21,9 @@ from torch.distributed._tensor import DTensor from torch.distributed.checkpoint.state_dict import get_model_state_dict from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp.fully_sharded_data_parallel import ( + FullyShardedDataParallel as FSDP, +) from transformers import PreTrainedModel from vllm import LLM from vllm.distributed import parallel_state as vllm_ps @@ -55,20 +57,30 @@ def __init__( self.torch_random_states = torch.cuda.get_rng_state() # get a random rng states gen_dp_rank = self.device_mesh["dp"].get_local_rank() - torch.cuda.manual_seed(gen_dp_rank + 1000) # make sure all tp ranks have the same random states + torch.cuda.manual_seed( + gen_dp_rank + 1000 + ) # make sure all tp ranks have the same random states self.gen_random_states = torch.cuda.get_rng_state() torch.cuda.set_rng_state(self.torch_random_states) - def _rename_weight_keys(self, actor_weights: Dict[str, Union[torch.Tensor, DTensor]], model: PreTrainedModel): + def _rename_weight_keys( + self, + actor_weights: Dict[str, Union[torch.Tensor, DTensor]], + model: PreTrainedModel, + ): # convert state dict keys: https://github.com/huggingface/transformers/pull/38385 if not hasattr(model, "_checkpoint_conversion_mapping"): return actor_weights - reverse_key_mapping = {v: k for k, v in model._checkpoint_conversion_mapping.items()} + reverse_key_mapping = { + v: k for k, v in model._checkpoint_conversion_mapping.items() + } original_weights = {} for key, value in actor_weights.items(): for pattern, replacement in reverse_key_mapping.items(): - replacement = replacement.lstrip("^") # strip off un-needed chars and patterns + replacement = replacement.lstrip( + "^" + ) # strip off un-needed chars and patterns replacement = re.sub(r"\(.*\)", "", replacement) key, n_replace = 
re.subn(pattern, replacement, key) # Early exit of the loop @@ -96,7 +108,9 @@ def __enter__(self): torch.cuda.empty_cache() print_gpu_memory_usage("Before state_dict() in sharding manager") actor_weights = get_model_state_dict(self.module) - actor_weights = self._rename_weight_keys(actor_weights, self.module._fsdp_wrapped_module) + actor_weights = self._rename_weight_keys( + actor_weights, self.module._fsdp_wrapped_module + ) print_gpu_memory_usage("After state_dict() in sharding manager") if "tags" in inspect.signature(self.inference_engine.wake_up).parameters: @@ -104,7 +118,9 @@ def __enter__(self): else: self.inference_engine.wake_up() - model = self.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model + model = ( + self.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model + ) model.load_weights(self._make_weight_iterator(actor_weights)) print_gpu_memory_usage("After sync model weights in sharding manager") @@ -114,7 +130,9 @@ def __enter__(self): if "tags" in inspect.signature(self.inference_engine.wake_up).parameters: self.inference_engine.wake_up(tags=["kv_cache"]) - print_gpu_memory_usage("After del state_dict and empty_cache in sharding manager") + print_gpu_memory_usage( + "After del state_dict and empty_cache in sharding manager" + ) # important: need to manually set the random states of each tp to be identical. if self.device_mesh is not None: self.torch_random_states = torch.cuda.get_rng_state() diff --git a/Agent0/curriculum_train/vllm_service_init/start_vllm_server_tool.py b/Agent0/curriculum_train/vllm_service_init/start_vllm_server_tool.py index 888960b..d4c7371 100644 --- a/Agent0/curriculum_train/vllm_service_init/start_vllm_server_tool.py +++ b/Agent0/curriculum_train/vllm_service_init/start_vllm_server_tool.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -''' -This script enhances the LLM's problem-solving capabilities by integrating a code execution tool. 
+""" +This script enhances the LLM's problem-solving capabilities by integrating a code execution tool. It processes each question through a multi-turn conversational approach, allowing the model to generate, execute, and reason based on code output. The generation process for each of the 10 candidates is now a stateful, iterative loop. @@ -13,7 +13,7 @@ # 3. Run the server python your_server_file_name.py --port 5000 --model_path Qwen/Qwen3-4B-Base -''' +""" from flask import Flask, request, jsonify import vllm @@ -34,15 +34,16 @@ # ---------------------------- Code Execution Tool --------------------------- # SANDBOX_API_URLS = [ - 'IP1:PORT1/run_code', - 'IP2:PORT2/run_code', - 'IP3:PORT3/run_code', - 'IP4:PORT4/run_code' + "IP1:PORT1/run_code", + "IP2:PORT2/run_code", + "IP3:PORT3/run_code", + "IP4:PORT4/run_code", ] api_counter_lock = threading.Lock() api_counter = 0 + def execute_code_in_sandbox(code: str) -> str: """ Calls an external sandbox API to execute Python code, with load balancing. 
@@ -54,8 +55,10 @@ def execute_code_in_sandbox(code: str) -> str: try: payload = {"code": code, "language": "python"} - headers = {'Content-Type': 'application/json'} - response = requests.post(target_url, headers=headers, data=json.dumps(payload), timeout=20) + headers = {"Content-Type": "application/json"} + response = requests.post( + target_url, headers=headers, data=json.dumps(payload), timeout=20 + ) response.raise_for_status() result = response.json() @@ -65,7 +68,7 @@ def execute_code_in_sandbox(code: str) -> str: stdout = run_info.get("stdout", "") return stdout if stdout else "[No output]" else: - stderr = run_info.get('stderr', '') + stderr = run_info.get("stderr", "") return f"Execution failed with status: {run_info.get('status')}\nStderr: {stderr}" else: return f"API Error: {result}" @@ -76,14 +79,18 @@ def execute_code_in_sandbox(code: str) -> str: # ---------------------------- Initial Setup --------------------------------- # parser = argparse.ArgumentParser() -parser.add_argument('--port', type=str, default='5000') -parser.add_argument('--model_path', type=str, default='Qwen/Qwen3-4B-Base') -parser.add_argument('--gpu_mem_util', type=float, default=0.8, - help='The maximum GPU memory utilization fraction for vLLM.') +parser.add_argument("--port", type=str, default="5000") +parser.add_argument("--model_path", type=str, default="Qwen/Qwen3-4B-Base") +parser.add_argument( + "--gpu_mem_util", + type=float, + default=0.8, + help="The maximum GPU memory utilization fraction for vLLM.", +) args = parser.parse_args() -print('[init] Loading model...') +print("[init] Loading model...") tokenizer = AutoTokenizer.from_pretrained(args.model_path) model = vllm.LLM( model=args.model_path, @@ -96,7 +103,7 @@ def execute_code_in_sandbox(code: str) -> str: temperature=0.7, top_p=0.9, n=1, - stop_token_ids=[tokenizer.eos_token_id] + stop_token_ids=[tokenizer.eos_token_id], ) SYSTEM_PROMPT = ( @@ -115,8 +122,9 @@ def execute_code_in_sandbox(code: str) -> str: 
stop_event = threading.Event() pause_event = threading.Event() + def gpu_idle_worker(): - print('[idle_worker] GPU idle worker started.') + print("[idle_worker] GPU idle worker started.") running = True while not stop_event.is_set(): if pause_event.is_set(): @@ -128,31 +136,41 @@ def gpu_idle_worker(): if not running: running = True try: - a = torch.rand((2000, 2000), dtype=torch.float32, device='cuda') - b = torch.rand((2000, 2000), dtype=torch.float32, device='cuda') + a = torch.rand((2000, 2000), dtype=torch.float32, device="cuda") + b = torch.rand((2000, 2000), dtype=torch.float32, device="cuda") torch.matmul(a, b) torch.cuda.synchronize() except RuntimeError: time.sleep(1) - print('[idle_worker] GPU idle worker stopped.') + print("[idle_worker] GPU idle worker stopped.") + idle_thread = threading.Thread(target=gpu_idle_worker, daemon=True) idle_thread.start() + # ---------------------------- Core Logic (Refactored) ----------------------- # -@stopit.threading_timeoutable(default='TIMED_OUT') +@stopit.threading_timeoutable(default="TIMED_OUT") def grade_answer_with_timeout(res1, res2): return grade_answer(res1, res2) + sandbox_executor = ThreadPoolExecutor(max_workers=64) + def generate_with_tool_use(question: str, num_candidates: int = 10, max_turns: int = 4): """ Generates answers using a multi-turn conversation loop (up to max_turns). Handles code execution and history updates dynamically. 
""" # Initialize conversation history for all candidates - conversations = [[{'role': 'system', 'content': SYSTEM_PROMPT}, {'role': 'user', 'content': question}] for _ in range(num_candidates)] + conversations = [ + [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": question}, + ] + for _ in range(num_candidates) + ] final_assistant_messages = [""] * num_candidates active_indices = list(range(num_candidates)) @@ -161,8 +179,13 @@ def generate_with_tool_use(question: str, num_candidates: int = 10, max_turns: i break # Prepare prompts only for active candidates - prompts = [tokenizer.apply_chat_template(conversations[i], tokenize=False, add_generation_prompt=True) for i in active_indices] - + prompts = [ + tokenizer.apply_chat_template( + conversations[i], tokenize=False, add_generation_prompt=True + ) + for i in active_indices + ] + # Batch generate responses = model.generate(prompts, sampling_params_single_turn, use_tqdm=False) @@ -173,30 +196,36 @@ def generate_with_tool_use(question: str, num_candidates: int = 10, max_turns: i for i, response in enumerate(responses): original_index = active_indices[i] model_output = response.outputs[0].text.strip() - + # Clean up potential incomplete code blocks code_block_start_tag = "```python" code_block_end_tag = "```" start_index = model_output.find(code_block_start_tag) if start_index != -1: - end_index = model_output.find(code_block_end_tag, start_index + len(code_block_start_tag)) + end_index = model_output.find( + code_block_end_tag, start_index + len(code_block_start_tag) + ) if end_index != -1: - model_output = model_output[:end_index + len(code_block_end_tag)] - + model_output = model_output[: end_index + len(code_block_end_tag)] + # Update history - conversations[original_index].append({'role': 'assistant', 'content': model_output}) + conversations[original_index].append( + {"role": "assistant", "content": model_output} + ) # Check for Code code_match = re.search(r"```python\n(.*?)\n```", 
model_output, re.DOTALL) - + # Check for Boxed Answer - has_boxed = r'\boxed' in model_output + has_boxed = r"\boxed" in model_output if code_match and not has_boxed: # Found code, no final answer yet -> Queue for execution code_to_run = (code_match.group(1) or "").strip() if code_to_run: - future = sandbox_executor.submit(execute_code_in_sandbox, code_to_run) + future = sandbox_executor.submit( + execute_code_in_sandbox, code_to_run + ) tasks_to_run.append((future, original_index)) indices_with_code.add(original_index) else: @@ -206,7 +235,7 @@ def generate_with_tool_use(question: str, num_candidates: int = 10, max_turns: i # Found answer -> Mark as finished final_assistant_messages[original_index] = model_output else: - # Pure text reasoning -> Will continue to next turn if logic requires, + # Pure text reasoning -> Will continue to next turn if logic requires, # or strictly speaking, we keep it active to allow further reasoning. pass @@ -222,23 +251,25 @@ def generate_with_tool_use(question: str, num_candidates: int = 10, max_turns: i next_active_indices = [] for i, response in enumerate(responses): original_index = active_indices[i] - + # If we already found a boxed answer, this candidate is done. if final_assistant_messages[original_index]: continue - + # If it had code, append result and keep active if original_index in indices_with_code: exec_result = results_map.get(original_index, "Result not found.") tool_feedback = f"Code execution result: {exec_result}" - conversations[original_index].append({'role': 'user', 'content': tool_feedback}) + conversations[original_index].append( + {"role": "user", "content": tool_feedback} + ) next_active_indices.append(original_index) - + # If it was just text (and no boxed), we keep it active for the next turn # (assuming it needs more steps), unless it was the last turn. 
else: next_active_indices.append(original_index) - + active_indices = next_active_indices # Fill in any candidates that didn't finish with \boxed with their last output @@ -247,39 +278,44 @@ def generate_with_tool_use(question: str, num_candidates: int = 10, max_turns: i # Use the last assistant message as the best effort result # Traverse backwards to find the last assistant message for msg in reversed(conversations[i]): - if msg['role'] == 'assistant': - final_assistant_messages[i] = msg['content'] + if msg["role"] == "assistant": + final_assistant_messages[i] = msg["content"] break - + return final_assistant_messages def consolidate_and_grade(question, golden_answer, assistant_messages): - '''Consolidates and grades LLM outputs for a single question.''' + """Consolidates and grades LLM outputs for a single question.""" results = [extract_boxed_content(msg) for msg in assistant_messages] - + answer_counts = {} for res in results: - if not res: continue + if not res: + continue matched = False - + for exist_ans in list(answer_counts.keys()): - if res == exist_ans or ('no ' in res.lower() and 'no ' in exist_ans.lower()): + if res == exist_ans or ( + "no " in res.lower() and "no " in exist_ans.lower() + ): answer_counts[exist_ans] += 1 matched = True break - + try: is_match = False match_result_1 = grade_answer_with_timeout(res, exist_ans, timeout=20) - if match_result_1 and match_result_1 != 'TIMED_OUT': + if match_result_1 and match_result_1 != "TIMED_OUT": is_match = True if not is_match: - match_result_2 = grade_answer_with_timeout(exist_ans, res, timeout=20) - if match_result_2 and match_result_2 != 'TIMED_OUT': + match_result_2 = grade_answer_with_timeout( + exist_ans, res, timeout=20 + ) + if match_result_2 and match_result_2 != "TIMED_OUT": is_match = True - + if is_match: answer_counts[exist_ans] += 1 matched = True @@ -287,12 +323,12 @@ def consolidate_and_grade(question, golden_answer, assistant_messages): except Exception: continue - + if not matched: 
answer_counts[res] = 1 if not answer_counts: - majority_ans, max_count = '', 0 + majority_ans, max_count = "", 0 else: majority_ans = max(answer_counts, key=answer_counts.get) max_count = answer_counts[majority_ans] @@ -300,66 +336,88 @@ def consolidate_and_grade(question, golden_answer, assistant_messages): score = max_count / len(assistant_messages) if assistant_messages else 0.0 return { - 'question': question, - 'answer': majority_ans, - 'score': score if grade_answer(majority_ans, golden_answer) and score > 0.1 else 0, - 'all_outputs': assistant_messages, - 'extracted_results': results + "question": question, + "answer": majority_ans, + "score": ( + score if grade_answer(majority_ans, golden_answer) and score > 0.1 else 0 + ), + "all_outputs": assistant_messages, + "extracted_results": results, } + # ---------------------------- Flask Application --------------------------- # app = Flask(__name__) -@app.route('/hello', methods=['GET']) + +@app.route("/hello", methods=["GET"]) def hello(): pause_event.set() torch.cuda.synchronize() - name = request.args.get('name', 'None') - - with open(name, 'r') as f: + name = request.args.get("name", "None") + + with open(name, "r") as f: data = json.load(f) os.remove(name) - questions = [item.get('question', '') for item in data] - answers = [item.get('answer', '') for item in data] + questions = [item.get("question", "") for item in data] + answers = [item.get("answer", "") for item in data] results_all = [] - + # Using TQDM for clean progress visualization - progress_bar = tqdm(zip(questions, answers), total=len(questions), desc=f"Processing {os.path.basename(name)}") - + progress_bar = tqdm( + zip(questions, answers), + total=len(questions), + desc=f"Processing {os.path.basename(name)}", + ) + for q, a in progress_bar: try: if q and a: # Multi-turn generation final_assistant_messages = generate_with_tool_use(q, max_turns=4) - + # Consolidate and Grade item = consolidate_and_grade(q, a, final_assistant_messages) 
results_all.append(item) else: - results_all.append({'question': q, 'answer': a, 'score': -1, 'all_outputs': [], 'extracted_results': []}) + results_all.append( + { + "question": q, + "answer": a, + "score": -1, + "all_outputs": [], + "extracted_results": [], + } + ) except Exception as e: # Only printing critical errors to not mess up TQDM too much - print(f'\n[server] Error processing question: {str(e)}') - results_all.append({ - 'question': q, 'answer': a, 'score': -1, 'error': f'unhandled exception: {str(e)}' - }) - - out_path = name.replace('.json', '_results.json') - with open(out_path, 'w') as f: + print(f"\n[server] Error processing question: {str(e)}") + results_all.append( + { + "question": q, + "answer": a, + "score": -1, + "error": f"unhandled exception: {str(e)}", + } + ) + + out_path = name.replace(".json", "_results.json") + with open(out_path, "w") as f: json.dump(results_all, f, indent=4) pause_event.clear() - return jsonify({'message': f'Processed {name}, results saved to {out_path}.'}) + return jsonify({"message": f"Processed {name}, results saved to {out_path}."}) + # ------------------------- Main Application Entrypoint --------------------------- # -if __name__ == '__main__': +if __name__ == "__main__": try: - app.run(host='127.0.0.1', port=int(args.port), threaded=True) + app.run(host="127.0.0.1", port=int(args.port), threaded=True) finally: stop_event.set() if idle_thread.is_alive(): idle_thread.join() - print('[main] Application shutdown complete.') \ No newline at end of file + print("[main] Application shutdown complete.") diff --git a/Agent0/executor_train/eval_service/app.py b/Agent0/executor_train/eval_service/app.py index 63b347a..54e28b3 100644 --- a/Agent0/executor_train/eval_service/app.py +++ b/Agent0/executor_train/eval_service/app.py @@ -16,32 +16,32 @@ # Set up logging logging.basicConfig( level=logging.ERROR, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - handlers=[ - logging.FileHandler("error_log.txt"), - 
logging.StreamHandler() - ] + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[logging.FileHandler("error_log.txt"), logging.StreamHandler()], ) logger = logging.getLogger(__name__) -def create_app(server_config: ServerConfig, model_config: ModelConfig, tool_config: ToolConfig) -> FastAPI: + +def create_app( + server_config: ServerConfig, model_config: ModelConfig, tool_config: ToolConfig +) -> FastAPI: """ Create and configure the FastAPI application - + Args: server_config: Server configuration object model_config: Model configuration object tool_config: Tool configuration object - + Returns: Configured FastAPI application instance """ app = FastAPI( title="LLM Code Tool Service", description="Large language model code tool calling service compatible with OpenAI API", - version="1.0.0" + version="1.0.0", ) - + # Add CORS middleware to allow cross-origin requests app.add_middleware( CORSMiddleware, @@ -50,18 +50,21 @@ def create_app(server_config: ServerConfig, model_config: ModelConfig, tool_conf allow_methods=["*"], allow_headers=["*"], ) - + # Set debug mode based on environment - if hasattr(server_config, "environment") and server_config.environment == "development": + if ( + hasattr(server_config, "environment") + and server_config.environment == "development" + ): app.debug = True - + # Initialize the model service model_service = ModelService(model_config, tool_config) model_service.load_model() - + # Store service in application state app.state.model_service = model_service - + # Add middleware for global exception handling @app.middleware("http") async def log_exceptions(request: Request, call_next): @@ -71,12 +74,12 @@ async def log_exceptions(request: Request, call_next): error_details = traceback.format_exc() logger.error(f"Unhandled exception: {str(e)}\n{error_details}") raise - + @app.post("/completions") async def chat_completions(request: Request): """ Chat completion API endpoint compatible with OpenAI - + Processes chat 
messages and returns model-generated responses with tool calling capabilities """ try: @@ -87,58 +90,72 @@ async def chat_completions(request: Request): except Exception as e: error_details = traceback.format_exc() logger.error(f"Error in completions endpoint: {str(e)}\n{error_details}") - raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") - + raise HTTPException( + status_code=500, detail=f"Internal server error: {str(e)}" + ) + @app.post("/chat/completions") async def completions(request: Request): """ Chat completion API endpoint compatible with OpenAI - + Processes chat messages and returns model-generated responses with tool calling capabilities """ try: request_body = await request.json() - logger.debug(f"Received chat completions request: {json.dumps(request_body)}") - response = await app.state.model_service.chat_completions_async(request_body) + logger.debug( + f"Received chat completions request: {json.dumps(request_body)}" + ) + response = await app.state.model_service.chat_completions_async( + request_body + ) return response except Exception as e: error_details = traceback.format_exc() - logger.error(f"Error in chat completions endpoint: {str(e)}\n{error_details}") - raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") - + logger.error( + f"Error in chat completions endpoint: {str(e)}\n{error_details}" + ) + raise HTTPException( + status_code=500, detail=f"Internal server error: {str(e)}" + ) + @app.get("/health") async def health_check(): """Health check endpoint to verify service availability""" return {"status": "healthy"} - + return app + async def main_async(): # Set up command line argument parsing hf_parser = HfArgumentParser((ServerConfig, ModelConfig, ToolConfig)) - server_config, model_config, tool_config = hf_parser.parse_args_into_dataclasses() + server_config, model_config, tool_config = hf_parser.parse_args_into_dataclasses() tool_config.post_init() - + # Create and run the 
application app = create_app(server_config, model_config, tool_config) - + # Configure and start the server with enhanced logging config = uvicorn.Config( - app, - host=server_config.host, - port=server_config.port, + app, + host=server_config.host, + port=server_config.port, log_level=server_config.log_level, # Changed from "error" to "debug" for better visibility - ws_max_queue=server_config.ws_max_queue, - workers=server_config.workers*model_config.num_models, + ws_max_queue=server_config.ws_max_queue, + workers=server_config.workers * model_config.num_models, access_log=True, - timeout_keep_alive=server_config.timeout_keep_alive # Added keep-alive timeout setting + timeout_keep_alive=server_config.timeout_keep_alive, # Added keep-alive timeout setting ) server = uvicorn.Server(config) await server.serve() + def main(): import asyncio + asyncio.run(main_async()) + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/Agent0/executor_train/eval_service/config.py b/Agent0/executor_train/eval_service/config.py index 64f9a2c..57f27c4 100644 --- a/Agent0/executor_train/eval_service/config.py +++ b/Agent0/executor_train/eval_service/config.py @@ -2,6 +2,7 @@ from typing import Optional, List, Dict, Any, Union from dataclasses import dataclass + @dataclass class ModelConfig: model: str @@ -10,18 +11,24 @@ class ModelConfig: trust_remote_code: bool = True num_models: int = 1 max_model_len: int = 4096 + + @dataclass class ToolConfig: tool_server_url: str = "http://localhost:30150/get_observation" max_turns: int = 5 # max generation turns - truncate_obs_side: str = "left" # "left" or "right", which side to truncate when the observation is too long + truncate_obs_side: str = ( + "left" # "left" or "right", which side to truncate when the observation is too long + ) action_stop_tokens: str = None max_obs_length: int = 512 # maximum length of observation - enable_mtrl: bool=False - mtrl_sep: str=None # 
"\n<|im_start|>system\n{obs}<|im_end|>\n<|im_start|>assistant\n" - turn_end_token: str="<|im_end|>" - min_turns: int=0 - + enable_mtrl: bool = False + mtrl_sep: str = ( + None # "\n<|im_start|>system\n{obs}<|im_end|>\n<|im_start|>assistant\n" + ) + turn_end_token: str = "<|im_end|>" + min_turns: int = 0 + def post_init(self): """ Post-initialization processing for ToolConfig (will not call automatically) @@ -30,15 +37,20 @@ def post_init(self): if isinstance(self.action_stop_tokens, str): if os.path.exists(self.action_stop_tokens): with open(self.action_stop_tokens, "r") as f: - self.action_stop_tokens = f.read().split(',') + self.action_stop_tokens = f.read().split(",") else: - self.action_stop_tokens = self.action_stop_tokens.split(',') - self.action_stop_tokens = [token.strip('\n ') for token in self.action_stop_tokens] - self.action_stop_tokens = [token for token in self.action_stop_tokens if token] + self.action_stop_tokens = self.action_stop_tokens.split(",") + self.action_stop_tokens = [ + token.strip("\n ") for token in self.action_stop_tokens + ] + self.action_stop_tokens = [ + token for token in self.action_stop_tokens if token + ] else: self.action_stop_tokens = None print(f"using action_stop_tokens: {self.action_stop_tokens}") + @dataclass class ServerConfig: host: str = "0.0.0.0" @@ -46,4 +58,4 @@ class ServerConfig: workers: int = 32 ws_max_queue: int = 1000 log_level: str = "error" - timeout_keep_alive: int = 60 \ No newline at end of file + timeout_keep_alive: int = 60 diff --git a/Agent0/executor_train/eval_service/model_service.py b/Agent0/executor_train/eval_service/model_service.py index 1d35cb1..74853a1 100644 --- a/Agent0/executor_train/eval_service/model_service.py +++ b/Agent0/executor_train/eval_service/model_service.py @@ -18,9 +18,10 @@ # other C0 control characters except common whitespace). 
CONTROL_CHAR_RE = re.compile( # this matches U+0000 through U+001F, excluding tab(09), LF(0A), CR(0D) - r'[\x00-\x08\x0B\x0C\x0E-\x1F]' + r"[\x00-\x08\x0B\x0C\x0E-\x1F]" ) + def sanitize_request(obj: Any) -> Any: """ Recursively walk through obj and: @@ -30,18 +31,21 @@ def sanitize_request(obj: Any) -> Any: - Leave other types untouched """ if isinstance(obj, dict): - return {sanitize_request(key): sanitize_request(val) for key, val in obj.items()} + return { + sanitize_request(key): sanitize_request(val) for key, val in obj.items() + } elif isinstance(obj, (list, tuple)): return type(obj)(sanitize_request(item) for item in obj) elif isinstance(obj, str): # strip NUL (\x00) and other C0 control chars - return CONTROL_CHAR_RE.sub('', obj) + return CONTROL_CHAR_RE.sub("", obj) else: return obj - + + class ModelService: """verl-tool model inference service""" - + def __init__(self, model_config: ModelConfig, tool_config: ToolConfig): """initialize model service""" self.model_config = model_config @@ -52,10 +56,18 @@ def __init__(self, model_config: ModelConfig, tool_config: ToolConfig): self.encode_lock = asyncio.Lock() if self.tool_config.mtrl_sep is None: messages = [{"role": "system", "content": "{obs}"}] - self.tool_config.mtrl_sep = "\n" + self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + self.tool_config.mtrl_sep = "\n" + self.tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) # self.tool_config.mtrl_sep = self.tool_config.mtrl_sep.replace("system", "user") - - def call_tool_server(self, trajectory_ids: List[str], actions: List[str], finish: List[bool], **kwargs: Dict[str, List[Any]]) -> Dict[str, Any]: + + def call_tool_server( + self, + trajectory_ids: List[str], + actions: List[str], + finish: List[bool], + **kwargs: Dict[str, List[Any]], + ) -> Dict[str, Any]: """querying the tool server for the observation and done flag""" server_url = self.tool_config.tool_server_url # prepare 
payload @@ -63,23 +75,32 @@ def call_tool_server(self, trajectory_ids: List[str], actions: List[str], finish "trajectory_ids": trajectory_ids, "actions": actions, "finish": finish, - **kwargs + **kwargs, } try: data = sanitize_request(data) response = requests.post(server_url, json=data) response.raise_for_status() result = response.json() - return result + return result except Exception as e: print(f"Error calling tool server: {str(e)}") return { - "observations": [f"Error calling tool server: {str(e)}" for _ in range(len(trajectory_ids))], + "observations": [ + f"Error calling tool server: {str(e)}" + for _ in range(len(trajectory_ids)) + ], "dones": [True for _ in range(len(trajectory_ids))], - "valids": [False for _ in range(len(trajectory_ids))] + "valids": [False for _ in range(len(trajectory_ids))], } - - async def call_tool_server_async(self, trajectory_ids: List[str], actions: List[str], finish: List[bool], **kwargs: Dict[str, List[Any]]) -> Dict[str, Any]: + + async def call_tool_server_async( + self, + trajectory_ids: List[str], + actions: List[str], + finish: List[bool], + **kwargs: Dict[str, List[Any]], + ) -> Dict[str, Any]: """querying the tool server for the observation and done flag using aiohttp""" server_url = self.tool_config.tool_server_url # prepare payload @@ -87,13 +108,13 @@ async def call_tool_server_async(self, trajectory_ids: List[str], actions: List[ "trajectory_ids": trajectory_ids, "actions": actions, "finish": finish, - **kwargs + **kwargs, } - + # Create aiohttp session if it doesn't exist if self.session is None: self.session = aiohttp.ClientSession() - + try: data = sanitize_request(data) async with self.session.post(server_url, json=data) as response: @@ -103,78 +124,110 @@ async def call_tool_server_async(self, trajectory_ids: List[str], actions: List[ except Exception as e: print(f"Error calling tool server: {str(e)}") return { - "observations": [f"Error calling tool server: {str(e)}" for _ in range(len(trajectory_ids))], + 
"observations": [ + f"Error calling tool server: {str(e)}" + for _ in range(len(trajectory_ids)) + ], "dones": [True for _ in range(len(trajectory_ids))], - "valids": [False for _ in range(len(trajectory_ids))] + "valids": [False for _ in range(len(trajectory_ids))], } - - async def post_process_observations(self, next_obs: List[str], dones: List[bool], valid_action: List[bool], finishs: List[bool]): + + async def post_process_observations( + self, + next_obs: List[str], + dones: List[bool], + valid_action: List[bool], + finishs: List[bool], + ): """Process observations using the tokenizer with proper async locks""" next_obs = [obs if not done else "" for obs, done in zip(next_obs, dones)] async with self.encode_lock: mtrl_sep = self.tool_config.mtrl_sep - if self.tool_config.truncate_obs_side == 'left': + if self.tool_config.truncate_obs_side == "left": next_obs_ids = self.tokenizer( next_obs, - padding='longest', - return_tensors='pt', + padding="longest", + return_tensors="pt", add_special_tokens=False, # Prevents adding special tokens - padding_side='left', - )['input_ids'].to(torch.int64) + padding_side="left", + )["input_ids"].to(torch.int64) if next_obs_ids.shape[1] > self.tool_config.max_obs_length: - print(f"[WARNING] OBSERVATION TOO LONG, CONSIDER CHANGING YOUR CONFIG, {next_obs_ids.shape[1]} & {self.tool_config.max_obs_length}") - next_obs_ids = next_obs_ids[:, -self.tool_config.max_obs_length:] - elif self.tool_config.truncate_obs_side == 'right': + print( + f"[WARNING] OBSERVATION TOO LONG, CONSIDER CHANGING YOUR CONFIG, {next_obs_ids.shape[1]} & {self.tool_config.max_obs_length}" + ) + next_obs_ids = next_obs_ids[:, -self.tool_config.max_obs_length :] + elif self.tool_config.truncate_obs_side == "right": next_obs_ids = self.tokenizer( next_obs, - padding='longest', - return_tensors='pt', + padding="longest", + return_tensors="pt", add_special_tokens=False, # Prevents adding special tokens - padding_side='right', - )['input_ids'].to(torch.int64) + 
padding_side="right", + )["input_ids"].to(torch.int64) if next_obs_ids.shape[1] > self.tool_config.max_obs_length: - print(f"[WARNING] OBSERVATION TOO LONG, CONSIDER CHANGING YOUR CONFIG, {next_obs_ids.shape[1]} & {self.tool_config.max_obs_length}") - next_obs_ids = next_obs_ids[:, :self.tool_config.max_obs_length] + print( + f"[WARNING] OBSERVATION TOO LONG, CONSIDER CHANGING YOUR CONFIG, {next_obs_ids.shape[1]} & {self.tool_config.max_obs_length}" + ) + next_obs_ids = next_obs_ids[:, : self.tool_config.max_obs_length] else: - raise ValueError(f"Invalid truncate_obs_side: {self.tool_config.truncate_obs_side}") + raise ValueError( + f"Invalid truncate_obs_side: {self.tool_config.truncate_obs_side}" + ) if self.tool_config.enable_mtrl: next_obs = self.tokenizer.batch_decode( - next_obs_ids, - skip_special_tokens=True + next_obs_ids, skip_special_tokens=True ) processed_next_obs = [] for i in range(len(next_obs)): if finishs[i] or dones[i]: # do action is false - assert next_obs[i] == "", f"next_obs should be empty when finishs is True, but got {next_obs[i]}" + assert ( + next_obs[i] == "" + ), f"next_obs should be empty when finishs is True, but got {next_obs[i]}" processed_next_obs.append("") elif valid_action[i]: processed_next_obs.append(mtrl_sep.format(obs=next_obs[i])) else: - processed_next_obs.append(mtrl_sep.format(obs="Your action is not valid, please check the format and try again." + next_obs[i])) + processed_next_obs.append( + mtrl_sep.format( + obs="Your action is not valid, please check the format and try again." 
+ + next_obs[i] + ) + ) next_obs = processed_next_obs next_obs_ids = self.tokenizer( next_obs, - padding='longest', - return_tensors='pt', + padding="longest", + return_tensors="pt", add_special_tokens=False, # Prevents adding special tokens - )['input_ids'].to(torch.int64) + )["input_ids"].to(torch.int64) next_obs = self.tokenizer.batch_decode( next_obs_ids, skip_special_tokens=True, ) return next_obs - - async def _postprocess_responses(self, outputs: torch.Tensor, action_step: int) -> torch.Tensor: + + async def _postprocess_responses( + self, outputs: torch.Tensor, action_step: int + ) -> torch.Tensor: """Process responses to stop at python operation or answer operation.""" - active_responses = [outputs.choices[i].text for i in range(len(outputs.choices))] - active_finish_reasons = [outputs.choices[i].finish_reason for i in range(len(outputs.choices))] - + active_responses = [ + outputs.choices[i].text for i in range(len(outputs.choices)) + ] + active_finish_reasons = [ + outputs.choices[i].finish_reason for i in range(len(outputs.choices)) + ] + finishes = [] for i in range(len(active_responses)): finish = True - if active_finish_reasons[i] == "stop" and outputs.choices[i].stop_reason is not None: - active_responses[i] = active_responses[i] + outputs.choices[i].stop_reason + if ( + active_finish_reasons[i] == "stop" + and outputs.choices[i].stop_reason is not None + ): + active_responses[i] = ( + active_responses[i] + outputs.choices[i].stop_reason + ) if self.tool_config.enable_mtrl: active_responses[i] += self.tool_config.turn_end_token finish = False @@ -187,32 +240,53 @@ async def _postprocess_responses(self, outputs: torch.Tensor, action_step: int) active_responses[i] += self.tool_config.turn_end_token finishes.append(finish) return active_responses, finishes, active_finish_reasons - + def load_model(self): """load the model using VLLM backend""" print(f"Loading Model using VLLM: {self.model_config.model}...") # start a VLLM server using vllm.serve - 
vllm_args = [f"--{k.replace('_', '-')}" for k in self.model_config.__dict__.keys() if k not in ["model", "api_key", "num_models", "host", "port"]] + vllm_args = [ + f"--{k.replace('_', '-')}" + for k in self.model_config.__dict__.keys() + if k not in ["model", "api_key", "num_models", "host", "port"] + ] vllm_args = [] for k, v in self.model_config.__dict__.items(): if k not in ["model", "api_key", "num_models", "host", "port"]: - vllm_args.append(f"--{k.replace('_', '-')}") - if not isinstance(v, bool): - vllm_args.append(str(v)) - + vllm_args.append(f"--{k.replace('_', '-')}") + if not isinstance(v, bool): + vllm_args.append(str(v)) + host = "0.0.0.0" num_models = self.model_config.num_models ports = random.sample(range(8000, 9000), num_models) self.vllm_processes = [] - gpu_ids = os.environ.get("CUDA_VISIBLE_DEVICES", ",".join([str(i) for i in range(torch.cuda.device_count())])).split(",") + gpu_ids = os.environ.get( + "CUDA_VISIBLE_DEVICES", + ",".join([str(i) for i in range(torch.cuda.device_count())]), + ).split(",") tensor_parallel_size = self.model_config.tensor_parallel_size - gpu_ids_per_model = [gpu_ids[i:i+tensor_parallel_size] for i in range(0, len(gpu_ids), tensor_parallel_size)] - assert len(gpu_ids) >= num_models * tensor_parallel_size, f"Not enough GPUs available: {len(gpu_ids)} < {num_models * tensor_parallel_size}" + gpu_ids_per_model = [ + gpu_ids[i : i + tensor_parallel_size] + for i in range(0, len(gpu_ids), tensor_parallel_size) + ] + assert ( + len(gpu_ids) >= num_models * tensor_parallel_size + ), f"Not enough GPUs available: {len(gpu_ids)} < {num_models * tensor_parallel_size}" for i in range(num_models): cmd = [ - "vllm", "serve", self.model_config.model, "--api-key", "token-abc123", - "--host", host, "--port", str(ports[i]), - "--disable-uvicorn-access-log", "--disable-log-stats", "--disable-log-requests" + "vllm", + "serve", + self.model_config.model, + "--api-key", + "token-abc123", + "--host", + host, + "--port", + str(ports[i]), + 
"--disable-uvicorn-access-log", + "--disable-log-stats", + "--disable-log-requests", ] + vllm_args env = os.environ.copy() env["CUDA_VISIBLE_DEVICES"] = ",".join(gpu_ids_per_model[i]) @@ -220,9 +294,12 @@ def load_model(self): vllm_process = subprocess.Popen(cmd, env=env) self.vllm_processes.append(vllm_process) self.clients = [ - openai.Client(api_key="token-abc123", base_url=f"http://{host}:{ports[i]}/v1") for i in range(num_models) + openai.Client( + api_key="token-abc123", base_url=f"http://{host}:{ports[i]}/v1" + ) + for i in range(num_models) ] - + # Wait for the service to start (poll the health endpoint) max_retries = 60 retry_interval = 10 @@ -239,66 +316,75 @@ def load_model(self): # print(f"vLLM instance model-{j} at {host}:{ports[j]} is not ready yet: {str(e)}") continue if all(vllm_model_status): - print(f"โœ… vLLM service started successfully with model: {self.model_config.model}") - return + print( + f"โœ… vLLM service started successfully with model: {self.model_config.model}" + ) + return else: time.sleep(retry_interval) - + # If we get here, the service failed to start print("Failed to start one or more vLLM services. 
Check vLLM logs.") for process in self.vllm_processes: stderr = process.stderr.read() print(f"vLLM stderr: {stderr}") process.terminate() - + raise RuntimeError("Failed to start vLLM services") - - async def send_request(self, client, prompts: List[str], model:str, sampling_params: dict) -> str: + + async def send_request( + self, client, prompts: List[str], model: str, sampling_params: dict + ) -> str: # Send the request using the client sampling_params = sampling_params.copy() # Use the async encode method to get tokens async with self.encode_lock: prompt_lens = [len(self.tokenizer.encode(prompt)) for prompt in prompts] max_prompt_tokens = max(prompt_lens) - - sampling_params['max_tokens'] = min(max(self.model_config.max_model_len - max_prompt_tokens, 0), sampling_params['max_tokens']) + + sampling_params["max_tokens"] = min( + max(self.model_config.max_model_len - max_prompt_tokens, 0), + sampling_params["max_tokens"], + ) # print(f"Sending request to {client.base_url} with sampling params: {sampling_params}") - + # Run the API call in an executor to not block the event loop response = await asyncio.get_event_loop().run_in_executor( None, lambda: client.completions.create( - model=model, - prompt=prompts, - echo=False, - stream=False, - **sampling_params - ) + model=model, prompt=prompts, echo=False, stream=False, **sampling_params + ), ) return response - - async def generate_with_tools(self, prompts: List[str], sampling_params: dict) -> Tuple[List[str], List[str]]: + + async def generate_with_tools( + self, prompts: List[str], sampling_params: dict + ) -> Tuple[List[str], List[str]]: """ Generate text with tool calls in a multi-turn loop. 
- + Args: prompts: Initial prompts for generation sampling_params: Sampling parameters for the model - + Returns: Tuple of (full_responses, finish_reasons) """ - client = random.choice(self.clients) # ensure the same trajectory uses the same client for prefix caching - assert sampling_params.get("n", 1) <= 1, "n > 1 is not supported yet for tool generation" + client = random.choice( + self.clients + ) # ensure the same trajectory uses the same client for prefix caching + assert ( + sampling_params.get("n", 1) <= 1 + ), "n > 1 is not supported yet for tool generation" contexts = prompts final_responses = ["" for _ in range(len(prompts))] traj_ids = [str(uuid.uuid4()) for _ in range(len(prompts))] active_masks = [True for _ in range(len(prompts))] finish_reasons = [None for _ in range(len(prompts))] model = self.model_config.model - + # keep trying to generate the response until reached the tool-calling limit - for action_step in range(self.tool_config.max_turns+1): + for action_step in range(self.tool_config.max_turns + 1): # print(f"Action step: {action_step}/{self.tool_config.max_turns}") if action_step == self.tool_config.max_turns: # last turn, don't stop by action stop tokens @@ -306,43 +392,49 @@ async def generate_with_tools(self, prompts: List[str], sampling_params: dict) - for action_stop_token in self.tool_config.action_stop_tokens: if action_stop_token in sampling_params["stop"]: sampling_params["stop"].remove(action_stop_token) - - active_traj_ids = [traj_ids[i] for i in range(len(traj_ids)) if active_masks[i]] - active_contexts = [contexts[i] for i in range(len(contexts)) if active_masks[i]] + + active_traj_ids = [ + traj_ids[i] for i in range(len(traj_ids)) if active_masks[i] + ] + active_contexts = [ + contexts[i] for i in range(len(contexts)) if active_masks[i] + ] if len(active_contexts) == 0: break - + # send request asynchronously outputs = await self.send_request( - client, - active_contexts, - model, - sampling_params + client, active_contexts, 
model, sampling_params ) - active_responses, finishes, active_finish_reasons = await self._postprocess_responses(outputs, action_step) - + active_responses, finishes, active_finish_reasons = ( + await self._postprocess_responses(outputs, action_step) + ) + # Use async tool server call if possible - if hasattr(self, 'call_tool_server_async'): + if hasattr(self, "call_tool_server_async"): tool_responses = await self.call_tool_server_async( - active_traj_ids, - active_responses, - finishes + active_traj_ids, active_responses, finishes ) else: # Fallback to sync version but run in executor tool_responses = await asyncio.get_event_loop().run_in_executor( - None, + None, self.call_tool_server, active_traj_ids, active_responses, - finishes + finishes, ) - + # print(f"Active observations (preprocess): {tool_responses['observations']}") - observations = await self.post_process_observations(tool_responses["observations"], tool_responses["dones"], tool_responses["valids"], finishes) + observations = await self.post_process_observations( + tool_responses["observations"], + tool_responses["dones"], + tool_responses["valids"], + finishes, + ) dones = tool_responses["dones"] valids = tool_responses["valids"] - + # print(f"Active step: {action_step}") # print(f"Active responses: {active_responses}") # print(f"Active observations: {observations}") @@ -354,51 +446,63 @@ async def generate_with_tools(self, prompts: List[str], sampling_params: dict) - active_idx = 0 for i in range(len(contexts)): if active_masks[i]: - contexts[i] += active_responses[active_idx] + observations[active_idx] - final_responses[i] += active_responses[active_idx] + observations[active_idx] + contexts[i] += ( + active_responses[active_idx] + observations[active_idx] + ) + final_responses[i] += ( + active_responses[active_idx] + observations[active_idx] + ) finish_reasons[i] = active_finish_reasons[active_idx] active_masks[i] = not dones[active_idx] active_idx += 1 - + return final_responses, finish_reasons - 
+ async def chat_completions_async(self, body: Dict[str, Any]) -> Dict[str, Any]: """process API request and generate response""" # print(f"Received request: {body}") - + if "messages" not in body or not body["messages"]: raise ValueError("No messages found in the request.") - if not 'user' in [message["role"] for message in body["messages"]]: + if not "user" in [message["role"] for message in body["messages"]]: raise ValueError("No user message found in the request.") - - assert body["model"] == self.model_config.model, f"model mismatch: {body['model']} != {self.model_config.model}" - + + assert ( + body["model"] == self.model_config.model + ), f"model mismatch: {body['model']} != {self.model_config.model}" + async with self.encode_lock: - prompt = self.tokenizer.apply_chat_template(body['messages'], - add_generation_prompt=True, - tokenize=False) - if body.get('n', 1) > 1: + prompt = self.tokenizer.apply_chat_template( + body["messages"], add_generation_prompt=True, tokenize=False + ) + if body.get("n", 1) > 1: prompts = [prompt for _ in range(body["n"])] else: prompts = [prompt] sampling_params = { "temperature": body.get("temperature", 1.0), - "max_tokens": body.get("max_tokens", body.get("max_completion_tokens", 512)), + "max_tokens": body.get( + "max_tokens", body.get("max_completion_tokens", 512) + ), "top_p": body.get("top_p", 1.0), - "stop": list(set(body.get("stop", []) + self.tool_config.action_stop_tokens)), + "stop": list( + set(body.get("stop", []) + self.tool_config.action_stop_tokens) + ), } # print(f"Sampling params: {sampling_params}") - all_responses, finish_reasons = await self.generate_with_tools(prompts, sampling_params) - + all_responses, finish_reasons = await self.generate_with_tools( + prompts, sampling_params + ) + async with self.encode_lock: prompt_tokens = len(self.tokenizer.encode(prompt)) completion_tokens = 0 for response in all_responses: completion_tokens += len(self.tokenizer.encode(response)) total_tokens = prompt_tokens + 
completion_tokens - + # format the response into OpenAI-compliant format return { "id": f"chatcmpl-{str(uuid.uuid4())}", @@ -412,49 +516,58 @@ async def chat_completions_async(self, body: Dict[str, Any]) -> Dict[str, Any]: "role": "assistant", "content": all_responses[i], }, - "finish_reason": finish_reasons[i] - } for i in range(len(all_responses)) + "finish_reason": finish_reasons[i], + } + for i in range(len(all_responses)) ], "usage": { "prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens, - "total_tokens": total_tokens - } + "total_tokens": total_tokens, + }, } - + def chat_completions(self, body: Dict[str, Any]) -> Dict[str, Any]: """Synchronous wrapper for chat_completions""" return asyncio.run(self.chat_completions_async(body)) - + async def completions_async(self, body: Dict[str, Any]) -> Dict[str, Any]: """process API request and generate response async""" # print(f"Received request: {body}") - if 'prompt' not in body: + if "prompt" not in body: raise ValueError("No prompt found in the request.") - assert body["model"] == self.model_config.model, f"model mismatch: {body['model']} != {self.model_config.model}" - prompt = body['prompt'] + assert ( + body["model"] == self.model_config.model + ), f"model mismatch: {body['model']} != {self.model_config.model}" + prompt = body["prompt"] - if body.get('n', 1) > 1: + if body.get("n", 1) > 1: prompts = [prompt for _ in range(body["n"])] else: prompts = [prompt] sampling_params = { "temperature": body.get("temperature", 1.0), - "max_tokens": body.get("max_tokens", body.get("max_completion_tokens", 512)), + "max_tokens": body.get( + "max_tokens", body.get("max_completion_tokens", 512) + ), "top_p": body.get("top_p", 1.0), - "stop": list(set(body.get("stop", []) + self.tool_config.action_stop_tokens)), + "stop": list( + set(body.get("stop", []) + self.tool_config.action_stop_tokens) + ), } - all_responses, finish_reasons = await self.generate_with_tools(prompts, sampling_params) - + all_responses, 
finish_reasons = await self.generate_with_tools( + prompts, sampling_params + ) + async with self.encode_lock: prompt_tokens = len(self.tokenizer.encode(prompt)) completion_tokens = 0 for response in all_responses: completion_tokens += len(self.tokenizer.encode(response)) total_tokens = prompt_tokens + completion_tokens - + # format the response into OpenAI-compliant format return { "id": f"chatcmpl-{str(uuid.uuid4())}", @@ -465,27 +578,28 @@ async def completions_async(self, body: Dict[str, Any]) -> Dict[str, Any]: { "index": i, "text": all_responses[i], - "finish_reason": finish_reasons[i] - } for i in range(len(all_responses)) + "finish_reason": finish_reasons[i], + } + for i in range(len(all_responses)) ], "usage": { "prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens, - "total_tokens": total_tokens - } + "total_tokens": total_tokens, + }, } - + def completions(self, body: Dict[str, Any]) -> Dict[str, Any]: """Synchronous wrapper for completions_async""" return asyncio.run(self.completions_async(body)) - + async def close(self): """Close any resources (like HTTP sessions and processes) when shutting down""" # Close HTTP session if self.session: await self.session.close() self.session = None - + # Terminate all VLLM processes for process in self.vllm_processes: if process: @@ -494,10 +608,10 @@ async def close(self): process.wait(timeout=5) except subprocess.TimeoutExpired: process.kill() - + self.vllm_processes = [] self.clients = [] - + def __del__(self): """Destructor to ensure resources are cleaned up""" try: diff --git a/Agent0/executor_train/eval_service/test/test_api.py b/Agent0/executor_train/eval_service/test/test_api.py index fff6835..9108a14 100644 --- a/Agent0/executor_train/eval_service/test/test_api.py +++ b/Agent0/executor_train/eval_service/test/test_api.py @@ -2,20 +2,23 @@ from openai import OpenAI from transformers import AutoTokenizer + def main( model_name: str, base_url: str, test_task: str = "math", - test_type: str = 
"chat_completion", # or "completion" + test_type: str = "chat_completion", # or "completion" api_key: str = "sk-proj-1234567890", temperature: float = 0.0, max_tokens: int = 2048, top_p: float = 1.0, n: int = 1, ): - client = OpenAI(api_key=api_key, base_url=base_url) # Replace with your local server address + client = OpenAI( + api_key=api_key, base_url=base_url + ) # Replace with your local server address tokenizer = AutoTokenizer.from_pretrained(model_name) - + # get test_task if test_task == "math": print("Testing math task...") @@ -23,33 +26,25 @@ def main( math_problem = "Convert the point $(0,3)$ in rectangular coordinates to polar coordinates. Enter your answer in the form $(r,\\theta),$ where $r > 0$ and $0 \\le \\theta < 2 \\pi.$" chat_messages = [ - { - "role": "system", - "content": system_prompt - }, - { - "role": "user", - "content": math_problem - } + {"role": "system", "content": system_prompt}, + {"role": "user", "content": math_problem}, ] - prompt = tokenizer.apply_chat_template(chat_messages, tokenize=False, add_generation_prompt=True) + prompt = tokenizer.apply_chat_template( + chat_messages, tokenize=False, add_generation_prompt=True + ) else: raise ValueError(f"Unknown test task: {test_task}") - - + if test_type == "chat_completion": - print(f"Testing {test_task} with {test_type} on model {model_name} at {base_url}", flush=True) + print( + f"Testing {test_task} with {test_type} on model {model_name} at {base_url}", + flush=True, + ) completion = client.chat.completions.create( model=model_name, messages=[ - { - "role": "system", - "content": system_prompt - }, - { - "role": "user", - "content": math_problem - } + {"role": "system", "content": system_prompt}, + {"role": "user", "content": math_problem}, ], temperature=temperature, max_tokens=max_tokens, @@ -58,18 +53,17 @@ def main( ) print(completion.choices[0].message.content) elif test_type == "completion": - print(f"Testing {test_task} with {test_type} on model {model_name} at {base_url}", 
flush=True) + print( + f"Testing {test_task} with {test_type} on model {model_name} at {base_url}", + flush=True, + ) chat_messages = [ - { - "role": "system", - "content": system_prompt - }, - { - "role": "user", - "content": math_problem - } + {"role": "system", "content": system_prompt}, + {"role": "user", "content": math_problem}, ] - prompt = tokenizer.apply_chat_template(chat_messages, tokenize=False, add_generation_prompt=True) + prompt = tokenizer.apply_chat_template( + chat_messages, tokenize=False, add_generation_prompt=True + ) completion = client.completions.create( model=model_name, prompt=prompt, @@ -82,8 +76,10 @@ def main( else: raise ValueError(f"Unknown test type: {test_type}") + if __name__ == "__main__": import fire + fire.Fire(main) """ diff --git a/Agent0/executor_train/eval_service/test/test_api_mp.py b/Agent0/executor_train/eval_service/test/test_api_mp.py index 313b571..4381512 100644 --- a/Agent0/executor_train/eval_service/test/test_api_mp.py +++ b/Agent0/executor_train/eval_service/test/test_api_mp.py @@ -11,89 +11,95 @@ # Different variations of the math problem to simulate diverse requests math_problems = [ math_problem, - math_problem.replace("9-kilometer", "10-kilometer").replace("4 hours", "5 hours").replace("2 hours and 24 minutes", "3 hours"), - math_problem.replace("9-kilometer", "8-kilometer").replace("4 hours", "3 hours").replace("2 hours and 24 minutes", "1 hour and 48 minutes"), + math_problem.replace("9-kilometer", "10-kilometer") + .replace("4 hours", "5 hours") + .replace("2 hours and 24 minutes", "3 hours"), + math_problem.replace("9-kilometer", "8-kilometer") + .replace("4 hours", "3 hours") + .replace("2 hours and 24 minutes", "1 hour and 48 minutes"), math_problem.replace("s+\\frac{1}{2}", "s+\\frac{2}{3}"), - math_problem.replace("s+\\frac{1}{2}", "s+1") + math_problem.replace("s+\\frac{1}{2}", "s+1"), ] + async def send_request(client, problem_text, request_id): """Send a single request and measure the time it 
takes""" start_time = time.time() print(f"Starting request {request_id}...") - + try: completion = await client.chat.completions.create( model="GAIR/ToRL-1.5B", messages=[ - { - "role": "system", - "content": system_prompt - }, - { - "role": "user", - "content": problem_text - } + {"role": "system", "content": system_prompt}, + {"role": "user", "content": problem_text}, ], temperature=0, max_tokens=2048, top_p=1, n=1, ) - + end_time = time.time() print(f"Request {request_id} completed in {end_time - start_time:.2f} seconds") - + # Print a shortened version of the response for verification response_content = completion.choices[0].message.content print(f"Request {request_id} response (truncated): {response_content}...\n") - + return { "request_id": request_id, "duration": end_time - start_time, - "response": response_content + "response": response_content, } except Exception as e: end_time = time.time() - print(f"Request {request_id} failed after {end_time - start_time:.2f} seconds: {str(e)}") + print( + f"Request {request_id} failed after {end_time - start_time:.2f} seconds: {str(e)}" + ) return { "request_id": request_id, "duration": end_time - start_time, - "error": str(e) + "error": str(e), } + async def run_concurrent_test(num_concurrent=5, num_total=10): """Run multiple concurrent requests to test server performance""" client = AsyncOpenAI(api_key="sk-proj-1234567890", base_url="http://0.0.0.0:5000") - - print(f"Starting concurrent test with {num_concurrent} concurrent requests, {num_total} total requests") + + print( + f"Starting concurrent test with {num_concurrent} concurrent requests, {num_total} total requests" + ) start_time = time.time() - + # Create tasks for all requests tasks = [] for i in range(num_total): problem = math_problems[i % len(math_problems)] - tasks.append(send_request(client, problem, i+1)) - + tasks.append(send_request(client, problem, i + 1)) + # Run requests in batches of num_concurrent results = [] for i in range(0, len(tasks), 
num_concurrent): - batch = tasks[i:i+num_concurrent] + batch = tasks[i : i + num_concurrent] batch_results = await asyncio.gather(*batch) results.extend(batch_results) - + end_time = time.time() total_duration = end_time - start_time - + # Calculate statistics successful_requests = [r for r in results if "error" not in r] failed_requests = [r for r in results if "error" in r] - + if successful_requests: - avg_request_time = sum(r["duration"] for r in successful_requests) / len(successful_requests) + avg_request_time = sum(r["duration"] for r in successful_requests) / len( + successful_requests + ) else: avg_request_time = 0 - + # Print summary print("\n===== TEST RESULTS =====") print(f"Total test duration: {total_duration:.2f} seconds") @@ -102,46 +108,51 @@ async def run_concurrent_test(num_concurrent=5, num_total=10): print(f"Failed requests: {len(failed_requests)}") print(f"Average request time: {avg_request_time:.2f} seconds") print(f"Requests per second: {num_total / total_duration:.2f}") - + if failed_requests: print("\nFailed requests:") for req in failed_requests: print(f" Request {req['request_id']}: {req['error']}") + async def sequential_test_for_comparison(num_requests=5): """Run sequential requests as a baseline for comparison""" client = AsyncOpenAI(api_key="sk-proj-1234567890", base_url="http://0.0.0.0:5000") - + print(f"\nStarting sequential test with {num_requests} requests for comparison") start_time = time.time() - + results = [] for i in range(num_requests): problem = math_problems[i % len(math_problems)] result = await send_request(client, problem, f"seq-{i+1}") results.append(result) - + end_time = time.time() total_duration = end_time - start_time - + # Calculate statistics successful_requests = [r for r in results if "error" not in r] - + if successful_requests: - avg_request_time = sum(r["duration"] for r in successful_requests) / len(successful_requests) + avg_request_time = sum(r["duration"] for r in successful_requests) / len( + 
successful_requests + ) else: avg_request_time = 0 - + # Print summary print("\n===== SEQUENTIAL TEST RESULTS =====") print(f"Total test duration: {total_duration:.2f} seconds") print(f"Average request time: {avg_request_time:.2f} seconds") print(f"Requests per second: {num_requests / total_duration:.2f}") + async def main(): # Run both tests await run_concurrent_test(num_concurrent=3, num_total=6) await sequential_test_for_comparison(num_requests=3) + if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/Agent0/executor_train/scripts/visualize_entropy.py b/Agent0/executor_train/scripts/visualize_entropy.py index da3cb31..b8c4ad2 100644 --- a/Agent0/executor_train/scripts/visualize_entropy.py +++ b/Agent0/executor_train/scripts/visualize_entropy.py @@ -10,10 +10,13 @@ from tqdm import tqdm from collections import defaultdict -def plot_entropy_bar(entropy, labels, title="Token Entropy", save_path="entropy_plot.png"): + +def plot_entropy_bar( + entropy, labels, title="Token Entropy", save_path="entropy_plot.png" +): """ Plot the token entropy with color highlighting based on masks and background shading. - + Args: entropy (list): List of entropy values corresponding to each token. labels (List[str]): List of labels for the tokens, e.g., "prompt", "action" or "obs". 
@@ -22,25 +25,39 @@ def plot_entropy_bar(entropy, labels, title="Token Entropy", save_path="entropy_ """ # Color map for distinguishing between the parts color_map = {"prompt": "green", "action": "red", "obs": "blue"} - + plt.figure(figsize=(15 + len(entropy) * 0.01, 4)) clipped_entropy = np.clip(entropy, 0, 10) token_indices = np.arange(len(entropy)) # Initialize to hold color and label settings token_colors = [color_map.get(label, "gray") for label in labels] - alpha_values = [0.6 if label == "prompt" else 0.9 for label in labels] # Lighter for prompts, darker for actions and obs - + alpha_values = [ + 0.6 if label == "prompt" else 0.9 for label in labels + ] # Lighter for prompts, darker for actions and obs + # Plot background color for each section last_idx = 0 last_label = labels[0] for i in range(len(labels)): if labels[i] != last_label: - plt.axvspan(last_idx, i - 1, color=color_map[last_label], alpha=0.1, label=f"{last_label.capitalize()} Background") + plt.axvspan( + last_idx, + i - 1, + color=color_map[last_label], + alpha=0.1, + label=f"{last_label.capitalize()} Background", + ) last_idx = i last_label = labels[i] - plt.axvspan(last_idx, len(labels) - 1, color=color_map[last_label], alpha=0.1, label=f"{last_label.capitalize()} Background") - + plt.axvspan( + last_idx, + len(labels) - 1, + color=color_map[last_label], + alpha=0.1, + label=f"{last_label.capitalize()} Background", + ) + # Bar plot with clear separation for each token part for i in range(len(entropy)): plt.bar(i, clipped_entropy[i], color=token_colors[i], alpha=alpha_values[i]) @@ -51,20 +68,26 @@ def plot_entropy_bar(entropy, labels, title="Token Entropy", save_path="entropy_ plt.tight_layout() # Adding a legend to make distinction clear - plt.legend(handles=[plt.Line2D([0], [0], color=color_map["prompt"], lw=4), - plt.Line2D([0], [0], color=color_map["action"], lw=4), - plt.Line2D([0], [0], color=color_map["obs"], lw=4)], - labels=["Prompt", "Action", "Obs"], title="Token Type") - + 
plt.legend( + handles=[ + plt.Line2D([0], [0], color=color_map["prompt"], lw=4), + plt.Line2D([0], [0], color=color_map["action"], lw=4), + plt.Line2D([0], [0], color=color_map["obs"], lw=4), + ], + labels=["Prompt", "Action", "Obs"], + title="Token Type", + ) + # Grid lines for better readability - plt.grid(True, axis='y', linestyle='--', alpha=0.5) - + plt.grid(True, axis="y", linestyle="--", alpha=0.5) + plt.savefig(save_path, dpi=300) return save_path + def main( - file_path:str, - model_name:str = "Qwen/Qwen2.5-Math-1.5B", + file_path: str, + model_name: str = "Qwen/Qwen2.5-Math-1.5B", batch_size=4, vis_dir: str = "entropy_vis", ): @@ -74,42 +97,65 @@ def main( pad_token_id = tokenizer.pad_token_id # Read the JSON file - with open(file_path, 'r') as f: + with open(file_path, "r") as f: data = json.load(f) data = datasets.Dataset.from_list(data) - data = data.filter(lambda x: x['num_turn'] > 0, num_proc=8, desc="Filtering dataset with num_turn > 0") + data = data.filter( + lambda x: x["num_turn"] > 0, + num_proc=8, + desc="Filtering dataset with num_turn > 0", + ) print(data) - full_inputs = [x['prompt'] + x['response'] for x in data] - full_inputs_with_mask = [x['prompt'] + x['response_with_loss_mask'] for x in data] + full_inputs = [x["prompt"] + x["response"] for x in data] + full_inputs_with_mask = [x["prompt"] + x["response_with_loss_mask"] for x in data] # Tokenize the inputs vis_dir = Path(vis_dir) vis_dir.mkdir(parents=True, exist_ok=True) vis_paths = [] - entropy_avgs = [] # list of sum entropy values, [0] for prompt, [1] for action 1, [2] for obs 1, [3] for action 2, [4] for obs 2, ... 
- for i in tqdm(range(0, len(full_inputs), batch_size), desc="Processing batches", total=len(full_inputs) // batch_size): - prompts = data['prompt'][i:i + batch_size] - batch = full_inputs[i:i + batch_size] - batch_with_mask = full_inputs_with_mask[i:i + batch_size] - inputs = tokenizer(batch, return_tensors='pt', padding="longest").to(model.device) - inputs_with_mask = tokenizer(batch_with_mask, return_tensors='pt', padding="longest").to(model.device) - attention_mask = inputs['attention_mask'] + entropy_avgs = ( + [] + ) # list of sum entropy values, [0] for prompt, [1] for action 1, [2] for obs 1, [3] for action 2, [4] for obs 2, ... + for i in tqdm( + range(0, len(full_inputs), batch_size), + desc="Processing batches", + total=len(full_inputs) // batch_size, + ): + prompts = data["prompt"][i : i + batch_size] + batch = full_inputs[i : i + batch_size] + batch_with_mask = full_inputs_with_mask[i : i + batch_size] + inputs = tokenizer(batch, return_tensors="pt", padding="longest").to( + model.device + ) + inputs_with_mask = tokenizer( + batch_with_mask, return_tensors="pt", padding="longest" + ).to(model.device) + attention_mask = inputs["attention_mask"] # Get the model outputs with torch.no_grad(): outputs = model(**inputs) - logits = outputs.logits # [batch_size, seq_len, vocab_size] + logits = outputs.logits # [batch_size, seq_len, vocab_size] probs = torch.softmax(logits, dim=-1) # [batch_size, seq_len, vocab_size] log_probs = torch.log(probs + 1e-9) # [batch_size, seq_len, vocab_size] - batch_entropy = -(probs * log_probs * attention_mask.unsqueeze(-1)).sum(dim=-1) # [batch_size, seq_len] + batch_entropy = -(probs * log_probs * attention_mask.unsqueeze(-1)).sum( + dim=-1 + ) # [batch_size, seq_len] entrypy_list = [] - for j in tqdm(range(len(batch_entropy)), desc=f"Processing batch {i//batch_size}", leave=False, total=len(batch_entropy)): + for j in tqdm( + range(len(batch_entropy)), + desc=f"Processing batch {i//batch_size}", + leave=False, + 
total=len(batch_entropy), + ): effective_entry = batch_entropy[j][attention_mask[j] == 1].cpu().numpy() - labels = ["prompt"] * len(tokenizer.encode(prompts[j], add_special_tokens=False)) + labels = ["prompt"] * len( + tokenizer.encode(prompts[j], add_special_tokens=False) + ) labels += ["action"] * (len(effective_entry) - len(labels)) - masks = inputs_with_mask['input_ids'][j][attention_mask[j] == 1] + masks = inputs_with_mask["input_ids"][j][attention_mask[j] == 1] masks = (masks != pad_token_id).cpu().numpy() for k in range(len(labels)): if masks[k] == 0: @@ -130,7 +176,7 @@ def main( if len(entropy_avgs) <= k: entropy_avgs.append([]) entropy_avgs[k].append(avg_entropy[k]) - + entrypy_list.append(effective_entry) vis_paths.append(save_path) @@ -143,7 +189,7 @@ def main( else: print(f"Average obs {i//2} entropy: {avg:.4f}") - + if __name__ == "__main__": fire.Fire(main) @@ -157,4 +203,4 @@ def main( python scripts/visualize_entropy.py --file_path path/to/data.json --model_name Qwen/Qwen2.5-Math-1.5B --batch_size 1 python scripts/visualize_entropy.py --file_path /home/dongfu/WorkSpace/verl-tool/verl_step_records/torl-fsdp-agent-qwen_qwen2.5-math-1.5b-grpo-n16-b128-t1.0-lr1e-6debug/torl-step-1.json --model_name Qwen/Qwen2.5-Math-1.5B --batch_size 2 ``` -""" \ No newline at end of file +""" diff --git a/Agent0/executor_train/verl/examples/data_preprocess/full_hh_rlhf.py b/Agent0/executor_train/verl/examples/data_preprocess/full_hh_rlhf.py index 4625f28..c42db21 100644 --- a/Agent0/executor_train/verl/examples/data_preprocess/full_hh_rlhf.py +++ b/Agent0/executor_train/verl/examples/data_preprocess/full_hh_rlhf.py @@ -62,7 +62,9 @@ def generate_rm_dataset(target_hdfs_path_dir, local_dir="~/data/full_hh_rlh/rm") local_dir = os.path.expanduser(local_dir) os.makedirs(local_dir, exist_ok=True) - for dataset, name in zip([train_dataset, test_dataset], ["train", "test"], strict=True): + for dataset, name in zip( + [train_dataset, test_dataset], ["train", "test"], 
strict=True + ): output = {"prompt": [], "chosen": [], "rejected": []} for data in tqdm(dataset): # add chosen diff --git a/Agent0/executor_train/verl/examples/data_preprocess/geo3k.py b/Agent0/executor_train/verl/examples/data_preprocess/geo3k.py index 2df225d..7b43dee 100644 --- a/Agent0/executor_train/verl/examples/data_preprocess/geo3k.py +++ b/Agent0/executor_train/verl/examples/data_preprocess/geo3k.py @@ -72,8 +72,12 @@ def process_fn(example, idx): return process_fn - train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True, num_proc=8) - test_dataset = test_dataset.map(function=make_map_fn("test"), with_indices=True, num_proc=8) + train_dataset = train_dataset.map( + function=make_map_fn("train"), with_indices=True, num_proc=8 + ) + test_dataset = test_dataset.map( + function=make_map_fn("test"), with_indices=True, num_proc=8 + ) local_dir = args.local_dir hdfs_dir = args.hdfs_dir diff --git a/Agent0/executor_train/verl/examples/data_preprocess/geo3k_multiturn_w_tool.py b/Agent0/executor_train/verl/examples/data_preprocess/geo3k_multiturn_w_tool.py index 6e00691..019003c 100644 --- a/Agent0/executor_train/verl/examples/data_preprocess/geo3k_multiturn_w_tool.py +++ b/Agent0/executor_train/verl/examples/data_preprocess/geo3k_multiturn_w_tool.py @@ -88,8 +88,12 @@ def process_fn(example, idx): return process_fn - train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True, num_proc=8) - test_dataset = test_dataset.map(function=make_map_fn("test"), with_indices=True, num_proc=8) + train_dataset = train_dataset.map( + function=make_map_fn("train"), with_indices=True, num_proc=8 + ) + test_dataset = test_dataset.map( + function=make_map_fn("test"), with_indices=True, num_proc=8 + ) local_dir = args.local_dir hdfs_dir = args.hdfs_dir train_dataset.to_parquet(os.path.join(local_dir, "train.parquet")) diff --git a/Agent0/executor_train/verl/examples/data_preprocess/gsm8k.py 
b/Agent0/executor_train/verl/examples/data_preprocess/gsm8k.py index f39c4f0..ef27042 100644 --- a/Agent0/executor_train/verl/examples/data_preprocess/gsm8k.py +++ b/Agent0/executor_train/verl/examples/data_preprocess/gsm8k.py @@ -46,7 +46,9 @@ def extract_solution(solution_str): train_dataset = dataset["train"] test_dataset = dataset["test"] - instruction_following = 'Let\'s think step by step and output the final answer after "####".' + instruction_following = ( + 'Let\'s think step by step and output the final answer after "####".' + ) # add a row to each data item that represents a unique id def make_map_fn(split): diff --git a/Agent0/executor_train/verl/examples/data_preprocess/gsm8k_multiturn_w_interaction.py b/Agent0/executor_train/verl/examples/data_preprocess/gsm8k_multiturn_w_interaction.py index 718a874..3c56479 100644 --- a/Agent0/executor_train/verl/examples/data_preprocess/gsm8k_multiturn_w_interaction.py +++ b/Agent0/executor_train/verl/examples/data_preprocess/gsm8k_multiturn_w_interaction.py @@ -47,7 +47,9 @@ def extract_solution(solution_str): train_dataset = dataset["train"] test_dataset = dataset["test"] - instruction_following = "Let's think step by step and output the final answer after `####`." + instruction_following = ( + "Let's think step by step and output the final answer after `####`." 
+ ) # add a row to each data item that represents a unique id def make_map_fn(split): diff --git a/Agent0/executor_train/verl/examples/data_preprocess/gsm8k_multiturn_w_tool.py b/Agent0/executor_train/verl/examples/data_preprocess/gsm8k_multiturn_w_tool.py index 400d885..5206a8c 100644 --- a/Agent0/executor_train/verl/examples/data_preprocess/gsm8k_multiturn_w_tool.py +++ b/Agent0/executor_train/verl/examples/data_preprocess/gsm8k_multiturn_w_tool.py @@ -47,7 +47,9 @@ def extract_solution(solution_str): train_dataset = dataset["train"] test_dataset = dataset["test"] - instruction_following = "Let's think step by step and output the final answer after `####`." + instruction_following = ( + "Let's think step by step and output the final answer after `####`." + ) # add a row to each data item that represents a unique id def make_map_fn(split): diff --git a/Agent0/executor_train/verl/examples/data_preprocess/math_dataset.py b/Agent0/executor_train/verl/examples/data_preprocess/math_dataset.py index e2e5d35..429501b 100644 --- a/Agent0/executor_train/verl/examples/data_preprocess/math_dataset.py +++ b/Agent0/executor_train/verl/examples/data_preprocess/math_dataset.py @@ -44,7 +44,9 @@ def extract_solution(solution_str): train_dataset = dataset["train"] test_dataset = dataset["test"] - instruction_following = "Let's think step by step and output the final answer within \\boxed{}." + instruction_following = ( + "Let's think step by step and output the final answer within \\boxed{}." 
+ ) # add a row to each data item that represents a unique id def make_map_fn(split): diff --git a/Agent0/executor_train/verl/examples/data_preprocess/multiturn.py b/Agent0/executor_train/verl/examples/data_preprocess/multiturn.py index 4bf0192..626ab32 100644 --- a/Agent0/executor_train/verl/examples/data_preprocess/multiturn.py +++ b/Agent0/executor_train/verl/examples/data_preprocess/multiturn.py @@ -54,7 +54,10 @@ def main(): "content": "Quantum computing is a type of computing that uses quantum-mechanical phenomena, " "such as superposition and entanglement, to perform operations on data.", }, - {"role": "user", "content": "How is it different from classical computing?"}, + { + "role": "user", + "content": "How is it different from classical computing?", + }, { "role": "assistant", "content": "Classical computing uses bits that are either 0 or 1, while quantum computing uses " @@ -69,7 +72,10 @@ def main(): { "messages": [ {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "Write a simple Python function to calculate factorial."}, + { + "role": "user", + "content": "Write a simple Python function to calculate factorial.", + }, { "role": "assistant", "content": ( diff --git a/Agent0/executor_train/verl/examples/data_preprocess/preprocess_search_r1_dataset.py b/Agent0/executor_train/verl/examples/data_preprocess/preprocess_search_r1_dataset.py index a0c10d5..19d08eb 100644 --- a/Agent0/executor_train/verl/examples/data_preprocess/preprocess_search_r1_dataset.py +++ b/Agent0/executor_train/verl/examples/data_preprocess/preprocess_search_r1_dataset.py @@ -25,7 +25,9 @@ from verl.utils.hdfs_io import copy, makedirs # Setup logging -logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) logger = logging.getLogger(__name__) # Configuration constants @@ -58,7 +60,10 @@ def 
process_single_row(row, current_split_name, row_index): # Build prompt structure user_content = user_content_prefix.rstrip("\n") + question - prompt = [{"role": "system", "content": system_content}, {"role": "user", "content": user_content}] + prompt = [ + {"role": "system", "content": system_content}, + {"role": "user", "content": user_content}, + ] # Extract ground truth from reward_model or fallback to golden_answers reward_model_data = row.get("reward_model") @@ -73,7 +78,11 @@ def process_single_row(row, current_split_name, row_index): # Build tools kwargs structure tools_kwargs = { "search": { - "create_kwargs": {"ground_truth": ground_truth, "question": question, "data_source": data_source_tagged} + "create_kwargs": { + "ground_truth": ground_truth, + "question": question, + "data_source": data_source_tagged, + } } } @@ -126,18 +135,24 @@ def main(): logger.info(f"Loaded {len(df_raw)} rows from {parquet_filename}") def apply_process_row(row, split_name=split): - return process_single_row(row, current_split_name=split_name, row_index=row.name) + return process_single_row( + row, current_split_name=split_name, row_index=row.name + ) df_processed = df_raw.apply(apply_process_row, axis=1) # Save processed DataFrame output_file_path = os.path.join(local_save_dir, f"{split}.parquet") df_processed.to_parquet(output_file_path, index=False) - logger.info(f"Saved {len(df_processed)} processed rows to {output_file_path}") + logger.info( + f"Saved {len(df_processed)} processed rows to {output_file_path}" + ) processed_files.append(output_file_path) except EntryNotFoundError: - logger.warning(f"{parquet_filename} not found in repository {args.hf_repo_id}") + logger.warning( + f"{parquet_filename} not found in repository {args.hf_repo_id}" + ) except Exception as e: logger.error(f"Error processing {split} split: {e}") @@ -145,7 +160,9 @@ def apply_process_row(row, split_name=split): logger.warning("No data was processed or saved") return - logger.info(f"Successfully 
processed {len(processed_files)} files to {local_save_dir}") + logger.info( + f"Successfully processed {len(processed_files)} files to {local_save_dir}" + ) # Copy to HDFS if specified if args.hdfs_dir: @@ -158,16 +175,24 @@ def apply_process_row(row, split_name=split): if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Download Search-R1 from HuggingFace, process, and save to Parquet.") + parser = argparse.ArgumentParser( + description="Download Search-R1 from HuggingFace, process, and save to Parquet." + ) parser.add_argument( - "--hf_repo_id", default="PeterJinGo/nq_hotpotqa_train", help="HuggingFace dataset repository ID." + "--hf_repo_id", + default="PeterJinGo/nq_hotpotqa_train", + help="HuggingFace dataset repository ID.", ) parser.add_argument( "--local_dir", default="~/data/searchR1_processed_direct", help="Local directory to save the processed Parquet files.", ) - parser.add_argument("--hdfs_dir", default=None, help="Optional HDFS directory to copy the Parquet files to.") + parser.add_argument( + "--hdfs_dir", + default=None, + help="Optional HDFS directory to copy the Parquet files to.", + ) args = parser.parse_args() diff --git a/Agent0/executor_train/verl/examples/sglang_multiturn/search_r1_like/local_dense_retriever/download.py b/Agent0/executor_train/verl/examples/sglang_multiturn/search_r1_like/local_dense_retriever/download.py index 6fe5549..b8a7f0c 100644 --- a/Agent0/executor_train/verl/examples/sglang_multiturn/search_r1_like/local_dense_retriever/download.py +++ b/Agent0/executor_train/verl/examples/sglang_multiturn/search_r1_like/local_dense_retriever/download.py @@ -20,9 +20,18 @@ from huggingface_hub import hf_hub_download -parser = argparse.ArgumentParser(description="Download files from a Hugging Face dataset repository.") -parser.add_argument("--repo_id", type=str, default="PeterJinGo/wiki-18-e5-index", help="Hugging Face repository ID") -parser.add_argument("--save_path", type=str, required=True, help="Local 
directory to save files") +parser = argparse.ArgumentParser( + description="Download files from a Hugging Face dataset repository." +) +parser.add_argument( + "--repo_id", + type=str, + default="PeterJinGo/wiki-18-e5-index", + help="Hugging Face repository ID", +) +parser.add_argument( + "--save_path", type=str, required=True, help="Local directory to save files" +) args = parser.parse_args() diff --git a/Agent0/executor_train/verl/examples/sglang_multiturn/search_r1_like/local_dense_retriever/retrieval_server.py b/Agent0/executor_train/verl/examples/sglang_multiturn/search_r1_like/local_dense_retriever/retrieval_server.py index 2f67c14..dca4cf7 100644 --- a/Agent0/executor_train/verl/examples/sglang_multiturn/search_r1_like/local_dense_retriever/retrieval_server.py +++ b/Agent0/executor_train/verl/examples/sglang_multiturn/search_r1_like/local_dense_retriever/retrieval_server.py @@ -32,7 +32,9 @@ def load_corpus(corpus_path: str): - corpus = datasets.load_dataset("json", data_files=corpus_path, split="train", num_proc=4) + corpus = datasets.load_dataset( + "json", data_files=corpus_path, split="train", num_proc=4 + ) return corpus @@ -47,13 +49,19 @@ def load_model(model_path: str, use_fp16: bool = False): model.cuda() if use_fp16: model = model.half() - tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True, trust_remote_code=True) + tokenizer = AutoTokenizer.from_pretrained( + model_path, use_fast=True, trust_remote_code=True + ) return model, tokenizer -def pooling(pooler_output, last_hidden_state, attention_mask=None, pooling_method="mean"): +def pooling( + pooler_output, last_hidden_state, attention_mask=None, pooling_method="mean" +): if pooling_method == "mean": - last_hidden = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0) + last_hidden = last_hidden_state.masked_fill( + ~attention_mask[..., None].bool(), 0.0 + ) return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] elif pooling_method == "cls": return 
last_hidden_state[:, 0] @@ -71,7 +79,9 @@ def __init__(self, model_name, model_path, pooling_method, max_length, use_fp16) self.max_length = max_length self.use_fp16 = use_fp16 - self.model, self.tokenizer = load_model(model_path=model_path, use_fp16=use_fp16) + self.model, self.tokenizer = load_model( + model_path=model_path, use_fp16=use_fp16 + ) self.model.eval() @torch.no_grad() @@ -89,25 +99,35 @@ def encode(self, query_list: list[str], is_query=True) -> np.ndarray: if "bge" in self.model_name.lower(): if is_query: query_list = [ - f"Represent this sentence for searching relevant passages: {query}" for query in query_list + f"Represent this sentence for searching relevant passages: {query}" + for query in query_list ] inputs = self.tokenizer( - query_list, max_length=self.max_length, padding=True, truncation=True, return_tensors="pt" + query_list, + max_length=self.max_length, + padding=True, + truncation=True, + return_tensors="pt", ) inputs = {k: v.cuda() for k, v in inputs.items()} if "T5" in type(self.model).__name__: # T5-based retrieval model - decoder_input_ids = torch.zeros((inputs["input_ids"].shape[0], 1), dtype=torch.long).to( - inputs["input_ids"].device + decoder_input_ids = torch.zeros( + (inputs["input_ids"].shape[0], 1), dtype=torch.long + ).to(inputs["input_ids"].device) + output = self.model( + **inputs, decoder_input_ids=decoder_input_ids, return_dict=True ) - output = self.model(**inputs, decoder_input_ids=decoder_input_ids, return_dict=True) query_emb = output.last_hidden_state[:, 0, :] else: output = self.model(**inputs, return_dict=True) query_emb = pooling( - output.pooler_output, output.last_hidden_state, inputs["attention_mask"], self.pooling_method + output.pooler_output, + output.last_hidden_state, + inputs["attention_mask"], + self.pooling_method, ) if "dpr" not in self.model_name.lower(): query_emb = torch.nn.functional.normalize(query_emb, dim=-1) @@ -139,7 +159,9 @@ def _batch_search(self, query_list: list[str], num: int, 
return_score: bool): def search(self, query: str, num: int = None, return_score: bool = False): return self._search(query, num, return_score) - def batch_search(self, query_list: list[str], num: int = None, return_score: bool = False): + def batch_search( + self, query_list: list[str], num: int = None, return_score: bool = False + ): return self._batch_search(query_list, num, return_score) @@ -173,7 +195,10 @@ def _search(self, query: str, num: int = None, return_score: bool = False): hits = hits[:num] if self.contain_doc: - all_contents = [json.loads(self.searcher.doc(hit.docid).raw())["contents"] for hit in hits] + all_contents = [ + json.loads(self.searcher.doc(hit.docid).raw())["contents"] + for hit in hits + ] results = [ { "title": content.split("\n")[0].strip('"'), @@ -190,7 +215,9 @@ def _search(self, query: str, num: int = None, return_score: bool = False): else: return results - def _batch_search(self, query_list: list[str], num: int = None, return_score: bool = False): + def _batch_search( + self, query_list: list[str], num: int = None, return_score: bool = False + ): results = [] scores = [] for query in query_list: @@ -237,7 +264,9 @@ def _search(self, query: str, num: int = None, return_score: bool = False): else: return results - def _batch_search(self, query_list: list[str], num: int = None, return_score: bool = False): + def _batch_search( + self, query_list: list[str], num: int = None, return_score: bool = False + ): if isinstance(query_list, str): query_list = [query_list] if num is None: @@ -245,7 +274,9 @@ def _batch_search(self, query_list: list[str], num: int = None, return_score: bo results = [] scores = [] - for start_idx in tqdm(range(0, len(query_list), self.batch_size), desc="Retrieval process: "): + for start_idx in tqdm( + range(0, len(query_list), self.batch_size), desc="Retrieval process: " + ): query_batch = query_list[start_idx : start_idx + self.batch_size] batch_emb = self.encoder.encode(query_batch) batch_scores, batch_idxs = 
self.index.search(batch_emb, k=num) @@ -256,12 +287,21 @@ def _batch_search(self, query_list: list[str], num: int = None, return_score: bo flat_idxs = sum(batch_idxs, []) batch_results = load_docs(self.corpus, flat_idxs) # chunk them back - batch_results = [batch_results[i * num : (i + 1) * num] for i in range(len(batch_idxs))] + batch_results = [ + batch_results[i * num : (i + 1) * num] for i in range(len(batch_idxs)) + ] results.extend(batch_results) scores.extend(batch_scores) - del batch_emb, batch_scores, batch_idxs, query_batch, flat_idxs, batch_results + del ( + batch_emb, + batch_scores, + batch_idxs, + query_batch, + flat_idxs, + batch_results, + ) torch.cuda.empty_cache() if return_score: @@ -376,7 +416,10 @@ def retrieve_endpoint(request: QueryRequest): if __name__ == "__main__": parser = argparse.ArgumentParser(description="Launch the local faiss retriever.") parser.add_argument( - "--index_path", type=str, default="/home/peterjin/mnt/index/wiki-18/e5_Flat.index", help="Corpus indexing file." + "--index_path", + type=str, + default="/home/peterjin/mnt/index/wiki-18/e5_Flat.index", + help="Corpus indexing file.", ) parser.add_argument( "--corpus_path", @@ -384,12 +427,24 @@ def retrieve_endpoint(request: QueryRequest): default="/home/peterjin/mnt/data/retrieval-corpus/wiki-18.jsonl", help="Local corpus file.", ) - parser.add_argument("--topk", type=int, default=3, help="Number of retrieved passages for one query.") - parser.add_argument("--retriever_name", type=str, default="e5", help="Name of the retriever model.") parser.add_argument( - "--retriever_model", type=str, default="intfloat/e5-base-v2", help="Path of the retriever model." + "--topk", + type=int, + default=3, + help="Number of retrieved passages for one query.", + ) + parser.add_argument( + "--retriever_name", type=str, default="e5", help="Name of the retriever model." 
+ ) + parser.add_argument( + "--retriever_model", + type=str, + default="intfloat/e5-base-v2", + help="Path of the retriever model.", + ) + parser.add_argument( + "--faiss_gpu", action="store_true", help="Use GPU for computation" ) - parser.add_argument("--faiss_gpu", action="store_true", help="Use GPU for computation") args = parser.parse_args() diff --git a/Agent0/executor_train/verl/examples/split_placement/main_ppo_split.py b/Agent0/executor_train/verl/examples/split_placement/main_ppo_split.py index c438e7a..6eb7a5d 100644 --- a/Agent0/executor_train/verl/examples/split_placement/main_ppo_split.py +++ b/Agent0/executor_train/verl/examples/split_placement/main_ppo_split.py @@ -57,11 +57,15 @@ def __call__(self, data: DataProto, return_dict: bool = False): prompt_length = prompt_ids.shape[-1] - valid_prompt_length = data_item.batch["attention_mask"][:prompt_length].sum() + valid_prompt_length = data_item.batch["attention_mask"][ + :prompt_length + ].sum() valid_prompt_ids = prompt_ids[-valid_prompt_length:] response_ids = data_item.batch["responses"] - valid_response_length = data_item.batch["attention_mask"][prompt_length:].sum() + valid_response_length = data_item.batch["attention_mask"][ + prompt_length: + ].sum() valid_response_ids = response_ids[:valid_response_length] # decode @@ -74,7 +78,9 @@ def __call__(self, data: DataProto, return_dict: bool = False): data_source = data_item.non_tensor_batch["data_source"] compute_score_fn = _select_rm_score_fn(data_source) - score = compute_score_fn(solution_str=sequences_str, ground_truth=ground_truth) + score = compute_score_fn( + solution_str=sequences_str, ground_truth=ground_truth + ) reward_tensor[i, valid_response_length - 1] = score if data_source not in already_print_data_sources: @@ -95,7 +101,9 @@ def main(config): if not ray.is_initialized(): # this is for local ray cluster ray.init( - runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}}, + runtime_env={ + "env_vars": 
{"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"} + }, num_cpus=config.ray_init.num_cpus, ) @@ -111,7 +119,9 @@ def main_task(config): from verl.utils.fs import copy_to_local - pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values + pprint( + OmegaConf.to_container(config, resolve=True) + ) # resolve=True will eval symbol values OmegaConf.resolve(config) # download the checkpoint from hdfs @@ -152,13 +162,17 @@ def main_task(config): critic_pool_id = "critic_pool" if config.trainer.nnodes // 2 == 0 and config.trainer.n_gpus_per_node // 2 > 0: resource_pool_spec = { - actor_rollout_ref_pool_id: [config.trainer.n_gpus_per_node // 2] * config.trainer.nnodes, - critic_pool_id: [config.trainer.n_gpus_per_node // 2] * config.trainer.nnodes, + actor_rollout_ref_pool_id: [config.trainer.n_gpus_per_node // 2] + * config.trainer.nnodes, + critic_pool_id: [config.trainer.n_gpus_per_node // 2] + * config.trainer.nnodes, } else: resource_pool_spec = { - actor_rollout_ref_pool_id: [config.trainer.n_gpus_per_node] * (config.trainer.nnodes // 2), - critic_pool_id: [config.trainer.n_gpus_per_node] * (config.trainer.nnodes // 2), + actor_rollout_ref_pool_id: [config.trainer.n_gpus_per_node] + * (config.trainer.nnodes // 2), + critic_pool_id: [config.trainer.n_gpus_per_node] + * (config.trainer.nnodes // 2), } print(f"resource_pool_spec: {resource_pool_spec}") mapping = { @@ -192,7 +206,9 @@ def main_task(config): # Note that we always use function-based RM for validation val_reward_fn = RewardManager(tokenizer=tokenizer, num_examine=1) - resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) + resource_pool_manager = ResourcePoolManager( + resource_pool_spec=resource_pool_spec, mapping=mapping + ) RayPPOTrainer.fit = fit trainer = RayPPOTrainer( diff --git a/Agent0/executor_train/verl/examples/split_placement/split_monkey_patch.py 
b/Agent0/executor_train/verl/examples/split_placement/split_monkey_patch.py index ef58509..ebdc1a4 100644 --- a/Agent0/executor_train/verl/examples/split_placement/split_monkey_patch.py +++ b/Agent0/executor_train/verl/examples/split_placement/split_monkey_patch.py @@ -59,7 +59,9 @@ def fit(self): # perform validation before training # currently, we only support validation using the reward_function. - if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True): + if self.val_reward_fn is not None and self.config.trainer.get( + "val_before_train", True + ): val_metrics = self._validate() pprint(f"Initial validation metrics: {val_metrics}") logger.log(data=val_metrics, step=self.global_steps) @@ -78,13 +80,17 @@ def fit(self): batch: DataProto = DataProto.from_single_dict(batch_dict) # pop those keys for generation - gen_batch = batch.pop(batch_keys=["input_ids", "attention_mask", "position_ids"]) + gen_batch = batch.pop( + batch_keys=["input_ids", "attention_mask", "position_ids"] + ) is_last_step = self.global_steps >= self.total_training_steps with marked_timer("step", timing_raw): # generate a batch with marked_timer("gen", timing_raw): - gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) + gen_batch_output = self.actor_rollout_wg.generate_sequences( + gen_batch + ) timing_raw.update(gen_batch_output.meta_info["timing"]) gen_batch_output.meta_info.pop("timing", None) @@ -92,7 +98,9 @@ def fit(self): with marked_timer("gen_max", timing_raw): gen_baseline_batch = deepcopy(gen_batch) gen_baseline_batch.meta_info["do_sample"] = False - gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch) + gen_baseline_output = self.actor_rollout_wg.generate_sequences( + gen_baseline_batch + ) batch = batch.union(gen_baseline_output) reward_baseline_tensor = self.reward_fn(batch) @@ -108,7 +116,10 @@ def fit(self): [str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object ) # repeat to align with 
repeated responses in rollout - batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) + batch = batch.repeat( + repeat_times=self.config.actor_rollout_ref.rollout.n, + interleave=True, + ) batch = batch.union(gen_batch_output) # Balance the number of valid tokens across DP ranks. @@ -119,7 +130,9 @@ def fit(self): self._balance_batch(batch, metrics=metrics) # compute global_valid tokens - batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist() + batch.meta_info["global_token_num"] = torch.sum( + batch.batch["attention_mask"], dim=-1 + ).tolist() # recompute old_log_probs with marked_timer("old_log_prob", timing_raw): @@ -154,14 +167,20 @@ def fit(self): # compute rewards. apply_kl_penalty if available if self.config.algorithm.use_kl_in_reward: batch, kl_metrics = apply_kl_penalty( - batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty + batch, + kl_ctrl=self.kl_ctrl_in_reward, + kl_penalty=self.config.algorithm.kl_penalty, ) metrics.update(kl_metrics) else: - batch.batch["token_level_rewards"] = batch.batch["token_level_scores"] + batch.batch["token_level_rewards"] = batch.batch[ + "token_level_scores" + ] # compute advantages, executed on the driver process - norm_adv_by_std_in_grpo = self.config.algorithm.get("norm_adv_by_std_in_grpo", True) + norm_adv_by_std_in_grpo = self.config.algorithm.get( + "norm_adv_by_std_in_grpo", True + ) batch = compute_advantage( batch, adv_estimator=self.config.algorithm.adv_estimator, @@ -187,19 +206,26 @@ def fit(self): # NOTE: make sure you set blocking=False in update_actor and update_crtic in the worker class with marked_timer("update_actor_critic", timing_raw): critic_output = critic_output.get() - critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"]) + critic_output_metrics = reduce_metrics( + critic_output.meta_info["metrics"] + ) metrics.update(critic_output_metrics) if actor_output is not None: 
actor_output = actor_output.get() - actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) + actor_output_metrics = reduce_metrics( + actor_output.meta_info["metrics"] + ) metrics.update(actor_output_metrics) # validate if ( self.val_reward_fn is not None and self.config.trainer.test_freq > 0 - and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0) + and ( + is_last_step + or self.global_steps % self.config.trainer.test_freq == 0 + ) ): with marked_timer("testing", timing_raw): val_metrics: dict = self._validate() @@ -208,13 +234,16 @@ def fit(self): metrics.update(val_metrics) if self.config.trainer.save_freq > 0 and ( - is_last_step or self.global_steps % self.config.trainer.save_freq == 0 + is_last_step + or self.global_steps % self.config.trainer.save_freq == 0 ): with marked_timer("save_checkpoint", timing_raw): self._save_checkpoint() # collect metrics - metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic)) + metrics.update( + compute_data_metrics(batch=batch, use_critic=self.use_critic) + ) metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) # TODO: make a canonical logger that supports various backend diff --git a/Agent0/executor_train/verl/recipe/char_count/create_dataset.py b/Agent0/executor_train/verl/recipe/char_count/create_dataset.py index 47571e0..c011ba4 100644 --- a/Agent0/executor_train/verl/recipe/char_count/create_dataset.py +++ b/Agent0/executor_train/verl/recipe/char_count/create_dataset.py @@ -138,9 +138,21 @@ def create_prompt_response(min_length=3, max_length=5): sft_test_dataset.to_parquet(os.path.join(folder, "test.parquet")) # build RL dataset - rl_train_dataset = {"prompt": [], "data_source": [], "ability": [], "reward_model": [], "extra_info": []} - - rl_test_dataset = {"prompt": [], "data_source": [], "ability": [], "reward_model": [], "extra_info": []} + rl_train_dataset = { + "prompt": [], + "data_source": [], + "ability": [], + "reward_model": 
[], + "extra_info": [], + } + + rl_test_dataset = { + "prompt": [], + "data_source": [], + "ability": [], + "reward_model": [], + "extra_info": [], + } from verl.utils.reward_score.math import last_boxed_only_string, remove_boxed @@ -158,7 +170,10 @@ def create_prompt_response(min_length=3, max_length=5): rl_train_dataset["data_source"].append("char_count") rl_train_dataset["ability"].append("other") rl_train_dataset["reward_model"].append( - {"style": "rule", "ground_truth": remove_boxed(last_boxed_only_string(response))} + { + "style": "rule", + "ground_truth": remove_boxed(last_boxed_only_string(response)), + } ) rl_train_dataset["extra_info"].append({"response": response}) @@ -176,7 +191,10 @@ def create_prompt_response(min_length=3, max_length=5): rl_test_dataset["data_source"].append("char_count") rl_test_dataset["ability"].append("other") rl_test_dataset["reward_model"].append( - {"style": "rule", "ground_truth": remove_boxed(last_boxed_only_string(response))} + { + "style": "rule", + "ground_truth": remove_boxed(last_boxed_only_string(response)), + } ) rl_test_dataset["extra_info"].append({"response": response}) diff --git a/Agent0/executor_train/verl/recipe/char_count/reward_function.py b/Agent0/executor_train/verl/recipe/char_count/reward_function.py index 9bdffe2..6635651 100644 --- a/Agent0/executor_train/verl/recipe/char_count/reward_function.py +++ b/Agent0/executor_train/verl/recipe/char_count/reward_function.py @@ -19,7 +19,9 @@ from verl.utils.reward_score import math -def char_count_reward_function(data_source, solution_str, ground_truth, extra_info=None): +def char_count_reward_function( + data_source, solution_str, ground_truth, extra_info=None +): try: last_boxed_string = math.last_boxed_only_string(solution_str) if last_boxed_string is None: diff --git a/Agent0/executor_train/verl/recipe/dapo/dapo_ray_trainer.py b/Agent0/executor_train/verl/recipe/dapo/dapo_ray_trainer.py index d3d6dbc..117613d 100644 --- 
a/Agent0/executor_train/verl/recipe/dapo/dapo_ray_trainer.py +++ b/Agent0/executor_train/verl/recipe/dapo/dapo_ray_trainer.py @@ -73,7 +73,9 @@ def fit(self): # perform validation before training # currently, we only support validation using the reward_function. - if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True): + if self.val_reward_fn is not None and self.config.trainer.get( + "val_before_train", True + ): val_metrics = self._validate() assert val_metrics, f"{val_metrics=}" pprint(f"Initial validation metrics: {val_metrics}") @@ -82,7 +84,11 @@ def fit(self): return # add tqdm - progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress") + progress_bar = tqdm( + total=self.total_training_steps, + initial=self.global_steps, + desc="Training Progress", + ) # we start from step 1 self.global_steps += 1 @@ -124,14 +130,19 @@ def fit(self): batch_keys=["input_ids", "attention_mask", "position_ids"], non_tensor_batch_keys=["raw_prompt_ids"], ) - gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) + gen_batch = gen_batch.repeat( + repeat_times=self.config.actor_rollout_ref.rollout.n, + interleave=True, + ) is_last_step = self.global_steps >= self.total_training_steps with marked_timer("step", timing_raw): # generate a batch with marked_timer("gen", timing_raw, "red"): - gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) + gen_batch_output = self.actor_rollout_wg.generate_sequences( + gen_batch + ) timing_raw.update(gen_batch_output.meta_info["timing"]) gen_batch_output.meta_info.pop("timing", None) @@ -139,23 +150,33 @@ def fit(self): with marked_timer("gen_max", timing_raw, "red"): gen_baseline_batch = deepcopy(gen_batch) gen_baseline_batch.meta_info["do_sample"] = False - gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch) + gen_baseline_output = ( + 
self.actor_rollout_wg.generate_sequences( + gen_baseline_batch + ) + ) new_batch = new_batch.union(gen_baseline_output) reward_baseline_tensor = self.reward_fn(new_batch) reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1) - new_batch.pop(batch_keys=list(gen_baseline_output.batch.keys())) + new_batch.pop( + batch_keys=list(gen_baseline_output.batch.keys()) + ) new_batch.batch["reward_baselines"] = reward_baseline_tensor del gen_baseline_batch, gen_baseline_output new_batch.non_tensor_batch["uid"] = np.array( - [str(uuid.uuid4()) for _ in range(len(new_batch.batch))], dtype=object + [str(uuid.uuid4()) for _ in range(len(new_batch.batch))], + dtype=object, ) # repeat to align with repeated responses in rollout - new_batch = new_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) + new_batch = new_batch.repeat( + repeat_times=self.config.actor_rollout_ref.rollout.n, + interleave=True, + ) new_batch = new_batch.union(gen_batch_output) with marked_timer("reward", timing_raw, "yellow"): @@ -172,7 +193,9 @@ def fit(self): try: reward_result = self.reward_fn(new_batch, return_dict=True) reward_tensor = reward_result["reward_tensor"] - reward_extra_infos_dict = reward_result.get("reward_extra_info", {}) + reward_extra_infos_dict = reward_result.get( + "reward_extra_info", {} + ) except Exception as e: print(f"Error in reward_fn: {e}") reward_tensor = self.reward_fn(new_batch) @@ -182,19 +205,26 @@ def fit(self): if reward_extra_infos_dict: new_batch.non_tensor_batch.update( - {k: np.array(v) for k, v in reward_extra_infos_dict.items()} + { + k: np.array(v) + for k, v in reward_extra_infos_dict.items() + } ) # compute rewards. 
apply_kl_penalty if available if self.config.algorithm.use_kl_in_reward: new_batch, kl_metrics = apply_kl_penalty( - new_batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty + new_batch, + kl_ctrl=self.kl_ctrl_in_reward, + kl_penalty=self.config.algorithm.kl_penalty, ) metrics.update( kl_metrics ) # TODO: This will be cleared if we use multiple genenration batches else: - new_batch.batch["token_level_rewards"] = new_batch.batch["token_level_scores"] + new_batch.batch["token_level_rewards"] = new_batch.batch[ + "token_level_scores" + ] if not self.config.algorithm.filter_groups.enable: batch = new_batch @@ -204,17 +234,23 @@ def fit(self): if metric_name == "seq_final_reward": # Turn to numpy for easier filtering new_batch.non_tensor_batch["seq_final_reward"] = ( - new_batch.batch["token_level_rewards"].sum(dim=-1).numpy() + new_batch.batch["token_level_rewards"] + .sum(dim=-1) + .numpy() ) elif metric_name == "seq_reward": new_batch.non_tensor_batch["seq_reward"] = ( - new_batch.batch["token_level_scores"].sum(dim=-1).numpy() + new_batch.batch["token_level_scores"] + .sum(dim=-1) + .numpy() ) # Collect the sequence reward for each trajectory prompt_uid2metric_vals = defaultdict(list) for uid, metric_val in zip( - new_batch.non_tensor_batch["uid"], new_batch.non_tensor_batch[metric_name], strict=True + new_batch.non_tensor_batch["uid"], + new_batch.non_tensor_batch[metric_name], + strict=True, ): prompt_uid2metric_vals[uid].append(metric_val) @@ -230,18 +266,29 @@ def fit(self): num_prompt_in_batch += len(kept_prompt_uids) kept_traj_idxs = [] - for idx, traj_from_prompt_uid in enumerate(new_batch.non_tensor_batch["uid"]): + for idx, traj_from_prompt_uid in enumerate( + new_batch.non_tensor_batch["uid"] + ): if traj_from_prompt_uid in kept_prompt_uids: kept_traj_idxs.append(idx) new_batch = new_batch[kept_traj_idxs] - batch = new_batch if batch is None else DataProto.concat([batch, new_batch]) + batch = ( + new_batch + if batch is None + 
else DataProto.concat([batch, new_batch]) + ) prompt_bsz = self.config.data.train_batch_size if num_prompt_in_batch < prompt_bsz: print(f"{num_prompt_in_batch=} < {prompt_bsz=}") - max_num_gen_batches = self.config.algorithm.filter_groups.max_num_gen_batches - if max_num_gen_batches <= 0 or num_gen_batches < max_num_gen_batches: + max_num_gen_batches = ( + self.config.algorithm.filter_groups.max_num_gen_batches + ) + if ( + max_num_gen_batches <= 0 + or num_gen_batches < max_num_gen_batches + ): print(f"{num_gen_batches=}. Keep generating...") progress_bar.update(1) continue @@ -253,7 +300,10 @@ def fit(self): ) else: # Align the batch - traj_bsz = self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n + traj_bsz = ( + self.config.data.train_batch_size + * self.config.actor_rollout_ref.rollout.n + ) batch = batch[:traj_bsz] # === Updating === @@ -269,16 +319,26 @@ def fit(self): self._balance_batch(batch, metrics=metrics) # compute global_valid tokens - batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist() + batch.meta_info["global_token_num"] = torch.sum( + batch.batch["attention_mask"], dim=-1 + ).tolist() # recompute old_log_probs with marked_timer("old_log_prob", timing_raw, "blue"): old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) entropys = old_log_prob.batch["entropys"] response_masks = batch.batch["response_mask"] - loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode - entropy_agg = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode) - old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()} + loss_agg_mode = ( + self.config.actor_rollout_ref.actor.loss_agg_mode + ) + entropy_agg = agg_loss( + loss_mat=entropys, + loss_mask=response_masks, + loss_agg_mode=loss_agg_mode, + ) + old_log_prob_metrics = { + "actor/entropy": entropy_agg.detach().item() + } metrics.update(old_log_prob_metrics) old_log_prob.batch.pop("entropys") batch 
= batch.union(old_log_prob) @@ -286,7 +346,9 @@ def fit(self): if self.use_reference_policy: # compute reference log_prob with marked_timer("ref", timing_raw, "olive"): - ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) + ref_log_prob = self.ref_policy_wg.compute_ref_log_prob( + batch + ) batch = batch.union(ref_log_prob) # compute values @@ -297,7 +359,9 @@ def fit(self): with marked_timer("adv", timing_raw, "brown"): # compute advantages, executed on the driver process - norm_adv_by_std_in_grpo = self.config.algorithm.get("norm_adv_by_std_in_grpo", True) + norm_adv_by_std_in_grpo = self.config.algorithm.get( + "norm_adv_by_std_in_grpo", True + ) batch = compute_advantage( batch, adv_estimator=self.config.algorithm.adv_estimator, @@ -311,7 +375,9 @@ def fit(self): if self.use_critic: with marked_timer("update_critic", timing_raw, "pink"): critic_output = self.critic_wg.update_critic(batch) - critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"]) + critic_output_metrics = reduce_metrics( + critic_output.meta_info["metrics"] + ) metrics.update(critic_output_metrics) # implement critic warmup @@ -319,14 +385,19 @@ def fit(self): # update actor with marked_timer("update_actor", timing_raw, "red"): actor_output = self.actor_rollout_wg.update_actor(batch) - actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) + actor_output_metrics = reduce_metrics( + actor_output.meta_info["metrics"] + ) metrics.update(actor_output_metrics) # validate if ( self.val_reward_fn is not None and self.config.trainer.test_freq > 0 - and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0) + and ( + is_last_step + or self.global_steps % self.config.trainer.test_freq == 0 + ) ): with marked_timer("testing", timing_raw, "green"): val_metrics: dict = self._validate() @@ -335,7 +406,8 @@ def fit(self): metrics.update(val_metrics) if self.config.trainer.save_freq > 0 and ( - is_last_step or self.global_steps % 
self.config.trainer.save_freq == 0 + is_last_step + or self.global_steps % self.config.trainer.save_freq == 0 ): with marked_timer("save_checkpoint", timing_raw, "green"): self._save_checkpoint() @@ -351,11 +423,19 @@ def fit(self): self.rm_wg.stop_profile() # collect metrics - metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic)) - metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) + metrics.update( + compute_data_metrics(batch=batch, use_critic=self.use_critic) + ) + metrics.update( + compute_timing_metrics(batch=batch, timing_raw=timing_raw) + ) # TODO: implement actual tflpo and theoretical tflpo n_gpus = self.resource_pool_manager.get_n_gpus() - metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus)) + metrics.update( + compute_throughout_metrics( + batch=batch, timing_raw=timing_raw, n_gpus=n_gpus + ) + ) timing_raw = defaultdict(float) # clear timing metrics["train/num_gen_batches"] = num_gen_batches diff --git a/Agent0/executor_train/verl/recipe/dapo/main_dapo.py b/Agent0/executor_train/verl/recipe/dapo/main_dapo.py index 1ee7359..afda3d8 100644 --- a/Agent0/executor_train/verl/recipe/dapo/main_dapo.py +++ b/Agent0/executor_train/verl/recipe/dapo/main_dapo.py @@ -38,7 +38,11 @@ def run_ppo(config) -> None: # this is for local ray cluster ray.init( runtime_env={ - "env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN", "VLLM_LOGGING_LEVEL": "WARN"} + "env_vars": { + "TOKENIZERS_PARALLELISM": "true", + "NCCL_DEBUG": "WARN", + "VLLM_LOGGING_LEVEL": "WARN", + } }, num_cpus=config.ray_init.num_cpus, ) @@ -48,7 +52,9 @@ def run_ppo(config) -> None: and OmegaConf.select(config.trainer, "profile_steps") is not None and len(OmegaConf.select(config.trainer, "profile_steps")) > 0 ): - nsight_options = OmegaConf.to_container(config.trainer.controller_nsight_options) + nsight_options = OmegaConf.to_container( + config.trainer.controller_nsight_options + ) runner = 
TaskRunner.options(runtime_env={"nsight": nsight_options}).remote() else: runner = TaskRunner.remote() @@ -67,7 +73,9 @@ def run(self, config): print(f"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}") - pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values + pprint( + OmegaConf.to_container(config, resolve=True) + ) # resolve=True will eval symbol values OmegaConf.resolve(config) # download the checkpoint from hdfs @@ -77,7 +85,9 @@ def run(self, config): from verl.utils import hf_processor, hf_tokenizer tokenizer = hf_tokenizer(local_path) - processor = hf_processor(local_path, use_fast=True) # used for multimodal LLM, could be none + processor = hf_processor( + local_path, use_fast=True + ) # used for multimodal LLM, could be none # define worker classes if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}: @@ -90,7 +100,10 @@ def run(self, config): elif config.actor_rollout_ref.actor.strategy == "megatron": assert config.actor_rollout_ref.actor.strategy == config.critic.strategy from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup - from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker + from verl.workers.megatron_workers import ( + ActorRolloutRefWorker, + CriticWorker, + ) ray_worker_group_cls = NVMegatronRayWorkerGroup @@ -130,7 +143,10 @@ def run(self, config): mapping[Role.RewardModel] = global_pool_id # reference model - if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss: + if ( + config.algorithm.use_kl_in_reward + or config.actor_rollout_ref.actor.use_kl_loss + ): role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker) mapping[Role.RefPolicy] = global_pool_id @@ -160,7 +176,9 @@ def run(self, config): max_resp_len=config.data.max_response_length, overlong_buffer_cfg=config.reward_model.overlong_buffer, ) - resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, 
mapping=mapping) + resource_pool_manager = ResourcePoolManager( + resource_pool_spec=resource_pool_spec, mapping=mapping + ) trainer = RayDAPOTrainer( config=config, diff --git a/Agent0/executor_train/verl/recipe/entropy/entropy_ray_trainer.py b/Agent0/executor_train/verl/recipe/entropy/entropy_ray_trainer.py index 0b0b043..0aa18b6 100644 --- a/Agent0/executor_train/verl/recipe/entropy/entropy_ray_trainer.py +++ b/Agent0/executor_train/verl/recipe/entropy/entropy_ray_trainer.py @@ -72,7 +72,9 @@ def fit(self): # perform validation before training # currently, we only support validation using the reward_function. - if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True): + if self.val_reward_fn is not None and self.config.trainer.get( + "val_before_train", True + ): val_metrics = self._validate() assert val_metrics, f"{val_metrics=}" pprint(f"Initial validation metrics: {val_metrics}") @@ -81,7 +83,11 @@ def fit(self): return # add tqdm - progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress") + progress_bar = tqdm( + total=self.total_training_steps, + initial=self.global_steps, + desc="Training Progress", + ) # we start from step 1 self.global_steps += 1 @@ -101,14 +107,21 @@ def fit(self): if "multi_modal_inputs" in new_batch.non_tensor_batch.keys(): gen_batch = new_batch.pop( batch_keys=["input_ids", "attention_mask", "position_ids"], - non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data", "multi_modal_inputs"], + non_tensor_batch_keys=[ + "raw_prompt_ids", + "multi_modal_data", + "multi_modal_inputs", + ], ) else: gen_batch = new_batch.pop( batch_keys=["input_ids", "attention_mask", "position_ids"], non_tensor_batch_keys=["raw_prompt_ids"], ) - gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) + gen_batch = gen_batch.repeat( + repeat_times=self.config.actor_rollout_ref.rollout.n, + interleave=True, + ) is_last_step = 
self.global_steps >= self.total_training_steps @@ -118,31 +131,45 @@ def fit(self): # gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) with simple_timer("gen", timing_raw): if not self.async_rollout_mode: - gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) + gen_batch_output = self.actor_rollout_wg.generate_sequences( + gen_batch + ) else: - gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch) + gen_batch_output = ( + self.async_rollout_manager.generate_sequences(gen_batch) + ) if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX: with simple_timer("gen_max", timing_raw): gen_baseline_batch = deepcopy(gen_batch) gen_baseline_batch.meta_info["do_sample"] = False - gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch) + gen_baseline_output = ( + self.actor_rollout_wg.generate_sequences( + gen_baseline_batch + ) + ) new_batch = new_batch.union(gen_baseline_output) reward_baseline_tensor = self.reward_fn(new_batch) reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1) - new_batch.pop(batch_keys=list(gen_baseline_output.batch.keys())) + new_batch.pop( + batch_keys=list(gen_baseline_output.batch.keys()) + ) new_batch.batch["reward_baselines"] = reward_baseline_tensor del gen_baseline_batch, gen_baseline_output new_batch.non_tensor_batch["uid"] = np.array( - [str(uuid.uuid4()) for _ in range(len(new_batch.batch))], dtype=object + [str(uuid.uuid4()) for _ in range(len(new_batch.batch))], + dtype=object, ) # repeat to align with repeated responses in rollout - new_batch = new_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) + new_batch = new_batch.repeat( + repeat_times=self.config.actor_rollout_ref.rollout.n, + interleave=True, + ) new_batch = new_batch.union(gen_batch_output) with simple_timer("reward", timing_raw): @@ -170,19 +197,26 @@ def fit(self): print(f"{list(reward_extra_infos_dict.keys())=}") if 
reward_extra_infos_dict: new_batch.non_tensor_batch.update( - {k: np.array(v) for k, v in reward_extra_infos_dict.items()} + { + k: np.array(v) + for k, v in reward_extra_infos_dict.items() + } ) # compute rewards. apply_kl_penalty if available if self.config.algorithm.use_kl_in_reward: new_batch, kl_metrics = apply_kl_penalty( - new_batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty + new_batch, + kl_ctrl=self.kl_ctrl_in_reward, + kl_penalty=self.config.algorithm.kl_penalty, ) metrics.update( kl_metrics ) # TODO: This will be cleared if we use multiple genenration batches else: - new_batch.batch["token_level_rewards"] = new_batch.batch["token_level_scores"] + new_batch.batch["token_level_rewards"] = new_batch.batch[ + "token_level_scores" + ] if not self.config.algorithm.filter_groups.enable: batch = new_batch @@ -192,17 +226,23 @@ def fit(self): if metric_name == "seq_final_reward": # Turn to numpy for easier filtering new_batch.non_tensor_batch["seq_final_reward"] = ( - new_batch.batch["token_level_rewards"].sum(dim=-1).numpy() + new_batch.batch["token_level_rewards"] + .sum(dim=-1) + .numpy() ) elif metric_name == "seq_reward": new_batch.non_tensor_batch["seq_reward"] = ( - new_batch.batch["token_level_scores"].sum(dim=-1).numpy() + new_batch.batch["token_level_scores"] + .sum(dim=-1) + .numpy() ) # Collect the sequence reward for each trajectory prompt_uid2metric_vals = defaultdict(list) for uid, metric_val in zip( - new_batch.non_tensor_batch["uid"], new_batch.non_tensor_batch[metric_name], strict=True + new_batch.non_tensor_batch["uid"], + new_batch.non_tensor_batch[metric_name], + strict=True, ): prompt_uid2metric_vals[uid].append(metric_val) @@ -218,18 +258,29 @@ def fit(self): num_prompt_in_batch += len(kept_prompt_uids) kept_traj_idxs = [] - for idx, traj_from_prompt_uid in enumerate(new_batch.non_tensor_batch["uid"]): + for idx, traj_from_prompt_uid in enumerate( + new_batch.non_tensor_batch["uid"] + ): if 
traj_from_prompt_uid in kept_prompt_uids: kept_traj_idxs.append(idx) new_batch = new_batch[kept_traj_idxs] - batch = new_batch if batch is None else DataProto.concat([batch, new_batch]) + batch = ( + new_batch + if batch is None + else DataProto.concat([batch, new_batch]) + ) prompt_bsz = self.config.data.train_batch_size if num_prompt_in_batch < prompt_bsz: print(f"{num_prompt_in_batch=} < {prompt_bsz=}") - max_num_gen_batches = self.config.algorithm.filter_groups.max_num_gen_batches - if max_num_gen_batches <= 0 or num_gen_batches < max_num_gen_batches: + max_num_gen_batches = ( + self.config.algorithm.filter_groups.max_num_gen_batches + ) + if ( + max_num_gen_batches <= 0 + or num_gen_batches < max_num_gen_batches + ): print(f"{num_gen_batches=}. Keep generating...") continue else: @@ -240,7 +291,10 @@ def fit(self): ) else: # Align the batch - traj_bsz = self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n + traj_bsz = ( + self.config.data.train_batch_size + * self.config.actor_rollout_ref.rollout.n + ) print( f"Collected {num_prompt_in_batch} / {self.config.data.train_batch_size} prompt. " f"Collecting finished." 
@@ -258,7 +312,9 @@ def fit(self): self._balance_batch(batch, metrics=metrics) # compute global_valid tokens - batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist() + batch.meta_info["global_token_num"] = torch.sum( + batch.batch["attention_mask"], dim=-1 + ).tolist() # recompute old_log_probs with simple_timer("old_log_prob", timing_raw): @@ -268,7 +324,9 @@ def fit(self): if self.use_reference_policy: # compute reference log_prob with simple_timer("ref", timing_raw): - ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) + ref_log_prob = self.ref_policy_wg.compute_ref_log_prob( + batch + ) batch = batch.union(ref_log_prob) # compute values @@ -279,7 +337,9 @@ def fit(self): with simple_timer("adv", timing_raw): # compute advantages, executed on the driver process - norm_adv_by_std_in_grpo = self.config.algorithm.get("norm_adv_by_std_in_grpo", True) + norm_adv_by_std_in_grpo = self.config.algorithm.get( + "norm_adv_by_std_in_grpo", True + ) batch = compute_advantage( batch, adv_estimator=self.config.algorithm.adv_estimator, @@ -293,7 +353,9 @@ def fit(self): if self.use_critic: with simple_timer("update_critic", timing_raw): critic_output = self.critic_wg.update_critic(batch) - critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"]) + critic_output_metrics = reduce_metrics( + critic_output.meta_info["metrics"] + ) metrics.update(critic_output_metrics) # implement critic warmup @@ -301,14 +363,19 @@ def fit(self): # update actor with simple_timer("update_actor", timing_raw): actor_output = self.actor_rollout_wg.update_actor(batch) - actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) + actor_output_metrics = reduce_metrics( + actor_output.meta_info["metrics"] + ) metrics.update(actor_output_metrics) # validate if ( self.val_reward_fn is not None and self.config.trainer.test_freq > 0 - and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0) + and ( + 
is_last_step + or self.global_steps % self.config.trainer.test_freq == 0 + ) ): with simple_timer("testing", timing_raw): val_metrics: dict = self._validate() @@ -317,17 +384,26 @@ def fit(self): metrics.update(val_metrics) if self.config.trainer.save_freq > 0 and ( - is_last_step or self.global_steps % self.config.trainer.save_freq == 0 + is_last_step + or self.global_steps % self.config.trainer.save_freq == 0 ): with simple_timer("save_checkpoint", timing_raw): self._save_checkpoint() # collect metrics - metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic)) - metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) + metrics.update( + compute_data_metrics(batch=batch, use_critic=self.use_critic) + ) + metrics.update( + compute_timing_metrics(batch=batch, timing_raw=timing_raw) + ) # TODO: implement actual tflpo and theoretical tflpo n_gpus = self.resource_pool_manager.get_n_gpus() - metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus)) + metrics.update( + compute_throughout_metrics( + batch=batch, timing_raw=timing_raw, n_gpus=n_gpus + ) + ) timing_raw = defaultdict(float) # clear timing metrics["train/num_gen_batches"] = num_gen_batches diff --git a/Agent0/executor_train/verl/recipe/entropy/main_entropy.py b/Agent0/executor_train/verl/recipe/entropy/main_entropy.py index a8bb0cb..756290c 100644 --- a/Agent0/executor_train/verl/recipe/entropy/main_entropy.py +++ b/Agent0/executor_train/verl/recipe/entropy/main_entropy.py @@ -71,7 +71,9 @@ def run(self, config): from verl.utils.fs import copy_to_local - pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values + pprint( + OmegaConf.to_container(config, resolve=True) + ) # resolve=True will eval symbol values OmegaConf.resolve(config) # download the checkpoint from hdfs @@ -82,13 +84,19 @@ def run(self, config): trust_remote_code = config.data.get("trust_remote_code", False) tokenizer = 
hf_tokenizer(local_path, trust_remote_code=trust_remote_code) - processor = hf_processor(local_path, use_fast=True) # used for multimodal LLM, could be none + processor = hf_processor( + local_path, use_fast=True + ) # used for multimodal LLM, could be none # define worker classes if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}: assert config.critic.strategy in {"fsdp", "fsdp2"} from verl.single_controller.ray import RayWorkerGroup - from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker, CriticWorker + from verl.workers.fsdp_workers import ( + ActorRolloutRefWorker, + AsyncActorRolloutRefWorker, + CriticWorker, + ) actor_rollout_cls = ( AsyncActorRolloutRefWorker @@ -100,7 +108,10 @@ def run(self, config): elif config.actor_rollout_ref.actor.strategy == "megatron": assert config.actor_rollout_ref.actor.strategy == config.critic.strategy from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup - from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker + from verl.workers.megatron_workers import ( + ActorRolloutRefWorker, + CriticWorker, + ) actor_rollout_cls = ActorRolloutRefWorker ray_worker_group_cls = NVMegatronRayWorkerGroup @@ -141,7 +152,10 @@ def run(self, config): mapping[Role.RewardModel] = global_pool_id # use reference model - if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss: + if ( + config.algorithm.use_kl_in_reward + or config.actor_rollout_ref.actor.use_kl_loss + ): role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker) mapping[Role.RefPolicy] = global_pool_id @@ -151,15 +165,26 @@ def run(self, config): } cfg_reward_kwargs = config.reward_model.get("reward_kwargs", {}) reward_fn = load_reward_manager( - config, tokenizer, num_examine=0, **OmegaConf.merge(OmegaConf.create(reward_kwargs), cfg_reward_kwargs) + config, + tokenizer, + num_examine=0, + **OmegaConf.merge(OmegaConf.create(reward_kwargs), 
cfg_reward_kwargs), + ) + val_reward_fn = load_reward_manager( + config, tokenizer, num_examine=1, **reward_kwargs + ) + resource_pool_manager = ResourcePoolManager( + resource_pool_spec=resource_pool_spec, mapping=mapping ) - val_reward_fn = load_reward_manager(config, tokenizer, num_examine=1, **reward_kwargs) - resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) from verl.utils.dataset.rl_dataset import collate_fn - train_dataset = create_rl_dataset(config.data.train_files, config.data, tokenizer, processor) - val_dataset = create_rl_dataset(config.data.val_files, config.data, tokenizer, processor) + train_dataset = create_rl_dataset( + config.data.train_files, config.data, tokenizer, processor + ) + val_dataset = create_rl_dataset( + config.data.val_files, config.data, tokenizer, processor + ) train_sampler = create_rl_sampler(config.data, train_dataset) trainer = RayEntropyTrainer( config=config, @@ -194,10 +219,15 @@ def create_rl_dataset(data_paths, data_config, tokenizer, processor): from verl.utils.dataset.rl_dataset import RLHFDataset - if "custom_cls" in data_config and data_config.custom_cls.get("path", None) is not None: + if ( + "custom_cls" in data_config + and data_config.custom_cls.get("path", None) is not None + ): from verl.utils.import_utils import load_extern_type - dataset_cls = load_extern_type(data_config.custom_cls.path, data_config.custom_cls.name) + dataset_cls = load_extern_type( + data_config.custom_cls.path, data_config.custom_cls.name + ) if not issubclass(dataset_cls, Dataset): raise TypeError( f"The custom dataset class '{data_config.custom_cls.name}' from '{data_config.custom_cls.path}' " @@ -234,7 +264,9 @@ def create_rl_sampler(data_config, dataset): if data_config.shuffle: train_dataloader_generator = torch.Generator() train_dataloader_generator.manual_seed(data_config.get("seed", 1)) - sampler = RandomSampler(data_source=dataset, generator=train_dataloader_generator) + sampler = 
RandomSampler( + data_source=dataset, generator=train_dataloader_generator + ) else: sampler = SequentialSampler(data_source=dataset) diff --git a/Agent0/executor_train/verl/recipe/entropy/reward.py b/Agent0/executor_train/verl/recipe/entropy/reward.py index 36b8b65..38f5dae 100644 --- a/Agent0/executor_train/verl/recipe/entropy/reward.py +++ b/Agent0/executor_train/verl/recipe/entropy/reward.py @@ -59,9 +59,13 @@ def load_reward_manager(config, tokenizer, num_examine, **reward_kwargs): if sandbox_url: sandbox_manager = multiprocessing.Manager() # Create a semaphore to control concurrent access to the sandbox - _concurrent_semaphore = sandbox_manager.Semaphore(sandbox_config.get("max_concurrent", 64)) + _concurrent_semaphore = sandbox_manager.Semaphore( + sandbox_config.get("max_concurrent", 64) + ) final_compute_score = partial( - _default_compute_score, sandbox_fusion_url=sandbox_url, concurrent_semaphore=_concurrent_semaphore + _default_compute_score, + sandbox_fusion_url=sandbox_url, + concurrent_semaphore=_concurrent_semaphore, ) else: final_compute_score = _default_compute_score @@ -82,5 +86,7 @@ def compute_reward_async(data: DataProto, config, tokenizer): Load the reward manager and compute the reward for a batch of data. This is meant to be run in a separate Ray worker. 
""" - reward_fn = load_reward_manager(config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {})) + reward_fn = load_reward_manager( + config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {}) + ) return compute_reward(data, reward_fn) diff --git a/Agent0/executor_train/verl/recipe/entropy/reward_score/__init__.py b/Agent0/executor_train/verl/recipe/entropy/reward_score/__init__.py index 7224bf3..7d8d882 100644 --- a/Agent0/executor_train/verl/recipe/entropy/reward_score/__init__.py +++ b/Agent0/executor_train/verl/recipe/entropy/reward_score/__init__.py @@ -19,7 +19,12 @@ def _default_compute_score( - data_source, solution_str, ground_truth, extra_info=None, sandbox_fusion_url=None, concurrent_semaphore=None + data_source, + solution_str, + ground_truth, + extra_info=None, + sandbox_fusion_url=None, + concurrent_semaphore=None, ): try: res = entropy_math.compute_score(solution_str, str(ground_truth)) diff --git a/Agent0/executor_train/verl/recipe/entropy/reward_score/entropy_math/__init__.py b/Agent0/executor_train/verl/recipe/entropy/reward_score/entropy_math/__init__.py index 1b2ba64..1c4239e 100644 --- a/Agent0/executor_train/verl/recipe/entropy/reward_score/entropy_math/__init__.py +++ b/Agent0/executor_train/verl/recipe/entropy/reward_score/entropy_math/__init__.py @@ -306,7 +306,11 @@ def _fix_sqrt(string): # replace tfrac and dfrac with frac string = string.replace("tfrac", "frac") string = string.replace("dfrac", "frac") - string = string.replace("\\neq", "\\ne").replace("\\leq", "\\le").replace("\\geq", "\\ge") + string = ( + string.replace("\\neq", "\\ne") + .replace("\\leq", "\\le") + .replace("\\geq", "\\ge") + ) # print(string) # remove \left and \right @@ -686,7 +690,9 @@ def is_value_equal(given_answer: str, ground_truth: str) -> bool: str_equal = ground_truth_normalized_mathd == given_answer_normalized_mathd try: - number_equal = float(ground_truth_normalized_mathd) == 
float(given_answer_normalized_mathd) + number_equal = float(ground_truth_normalized_mathd) == float( + given_answer_normalized_mathd + ) return str_equal or number_equal except Exception: return str_equal @@ -703,7 +709,10 @@ def _sympy_parse(expr: str): py_expr = expr.replace("^", "**") return sympy_parser.parse_expr( py_expr, - transformations=(sympy_parser.standard_transformations + (sympy_parser.implicit_multiplication_application,)), + transformations=( + sympy_parser.standard_transformations + + (sympy_parser.implicit_multiplication_application,) + ), ) @@ -971,13 +980,16 @@ def grade_answer_sympy(given_answer: str, ground_truth: str) -> bool: given_elems = split_tuple(given_normalized) if len(ground_truth_elems) > 1 and ( - ground_truth_normalized[0] != given_normalized[0] or ground_truth_normalized[-1] != given_normalized[-1] + ground_truth_normalized[0] != given_normalized[0] + or ground_truth_normalized[-1] != given_normalized[-1] ): is_correct = False elif len(ground_truth_elems) != len(given_elems): is_correct = False else: - for ground_truth_elem, given_elem in zip(ground_truth_elems, given_elems, strict=True): + for ground_truth_elem, given_elem in zip( + ground_truth_elems, given_elems, strict=True + ): if _is_frac(ground_truth_elem) and _is_frac(given_elem): # if fractions aren't reduced, then shouldn't be marked as correct # so, we don't want to allow sympy.simplify in this case @@ -1013,7 +1025,9 @@ def extract_answer(passage: str) -> str: def grade(model_answer: str, gt_answer: str, fast: bool = True): if "\\boxed" in gt_answer: gt_answer = extract_answer(gt_answer) - correct = grade_answer_mathd(model_answer, gt_answer) or grade_answer_sympy(model_answer, gt_answer) + correct = grade_answer_mathd(model_answer, gt_answer) or grade_answer_sympy( + model_answer, gt_answer + ) if not fast: # This mode further uses math_verify to recall originally false positives. # Will be a bit slower, and sensitive to bad inputs. 
diff --git a/Agent0/executor_train/verl/recipe/entropy/reward_score/entropy_math/grader.py b/Agent0/executor_train/verl/recipe/entropy/reward_score/entropy_math/grader.py index 02507e3..47dff95 100644 --- a/Agent0/executor_train/verl/recipe/entropy/reward_score/entropy_math/grader.py +++ b/Agent0/executor_train/verl/recipe/entropy/reward_score/entropy_math/grader.py @@ -125,7 +125,8 @@ def normalize(answer, pi) -> str: # checking if answer is % or \\% and removing % if isinstance(answer, str) and ( - bool(re.match(r"^\d+(\.\d+)?%$", answer)) or bool(re.match(r"^\d+(\.\d+)?\\%$", answer)) + bool(re.match(r"^\d+(\.\d+)?%$", answer)) + or bool(re.match(r"^\d+(\.\d+)?\\%$", answer)) ): return answer.replace("\\%", "").replace("%", "") @@ -188,7 +189,9 @@ def math_equal( prediction = normalize(prediction, pi) reference = normalize(reference, pi) - if isinstance(prediction, str) and len(prediction) > 1000: # handling weird corner-cases + if ( + isinstance(prediction, str) and len(prediction) > 1000 + ): # handling weird corner-cases prediction = prediction[:1000] # 0. 
string comparison @@ -203,7 +206,11 @@ def math_equal( prediction = is_digit(prediction)[1] reference = is_digit(reference)[1] # number questions - gt_result = [reference / 100, reference, reference * 100] if include_percentage else [reference] + gt_result = ( + [reference / 100, reference, reference * 100] + if include_percentage + else [reference] + ) for item in gt_result: try: if isclose(item, prediction, rel_tol=tolerance): @@ -225,8 +232,14 @@ def math_equal( prediction = format_intervals(prediction) pred_str, ref_str = prediction, reference - if (prediction.startswith("[") and prediction.endswith("]") and not reference.startswith("(")) or ( - prediction.startswith("(") and prediction.endswith(")") and not reference.startswith("[") + if ( + prediction.startswith("[") + and prediction.endswith("]") + and not reference.startswith("(") + ) or ( + prediction.startswith("(") + and prediction.endswith(")") + and not reference.startswith("[") ): pred_str = pred_str.strip("[]()") ref_str = ref_str.strip("[]()") @@ -263,7 +276,9 @@ def math_equal( return bool( all( [ - math_equal(pred_parts[i], ref_parts[i], include_percentage, tolerance) + math_equal( + pred_parts[i], ref_parts[i], include_percentage, tolerance + ) for i in range(len(pred_parts)) ] ) @@ -295,7 +310,11 @@ def math_equal( return True except Exception: pass - elif "\begin{pmatrix}" in reference and prediction.startswith("[") and prediction.endswith("]"): + elif ( + "\begin{pmatrix}" in reference + and prediction.startswith("[") + and prediction.endswith("]") + ): if isinstance(eval(prediction), list): try: pred_matrix = eval(prediction) @@ -307,7 +326,9 @@ def math_equal( .rstrip("\end{pmatrix}") ) # noqa: B005 ref_matrix_items = ref_matrix_items.split("\\") - ref_matrix_items = [row.split("&") if "&" in row else row for row in ref_matrix_items] + ref_matrix_items = [ + row.split("&") if "&" in row else row for row in ref_matrix_items + ] if len(pred_matrix) == len(ref_matrix_items) and all( [ 
math_equal(pred, ref, include_percentage, tolerance) diff --git a/Agent0/executor_train/verl/recipe/genrm_remote/reward_function.py b/Agent0/executor_train/verl/recipe/genrm_remote/reward_function.py index b2d3fbc..47d3824 100644 --- a/Agent0/executor_train/verl/recipe/genrm_remote/reward_function.py +++ b/Agent0/executor_train/verl/recipe/genrm_remote/reward_function.py @@ -81,7 +81,9 @@ def compute_score(data_source, solution_str, ground_truth, extra_info): split = extra_info["split"] from verl.utils.reward_score import default_compute_score - func_rm_score = default_compute_score(data_source, solution_str, ground_truth, extra_info) + func_rm_score = default_compute_score( + data_source, solution_str, ground_truth, extra_info + ) if split == "test": return func_rm_score @@ -102,7 +104,9 @@ def compute_score_batch(data_sources, solution_strs, ground_truths, extra_infos) for data_source, solution_str, ground_truth, extra_info in zip( data_sources, solution_strs, ground_truths, extra_infos, strict=True ): - future = executor.submit(compute_score, data_source, solution_str, ground_truth, extra_info) + future = executor.submit( + compute_score, data_source, solution_str, ground_truth, extra_info + ) futures.append(future) results = [future.result() for future in futures] diff --git a/Agent0/executor_train/verl/recipe/minicpmo/rl_dataset.py b/Agent0/executor_train/verl/recipe/minicpmo/rl_dataset.py index 5ce15fb..97ffd48 100644 --- a/Agent0/executor_train/verl/recipe/minicpmo/rl_dataset.py +++ b/Agent0/executor_train/verl/recipe/minicpmo/rl_dataset.py @@ -42,15 +42,21 @@ def build_transform(): return transforms.Compose( [ transforms.ToTensor(), - transforms.Normalize(mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + transforms.Normalize( + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD + ), ] ) def build_image_bound(input_ids, tokenizer, new_schema=True, logger=None): if new_schema: - start_cond = (input_ids == tokenizer.im_start_id) | (input_ids 
== tokenizer.slice_start_id) - end_cond = (input_ids == tokenizer.im_end_id) | (input_ids == tokenizer.slice_end_id) + start_cond = (input_ids == tokenizer.im_start_id) | ( + input_ids == tokenizer.slice_start_id + ) + end_cond = (input_ids == tokenizer.im_end_id) | ( + input_ids == tokenizer.slice_end_id + ) else: start_cond = input_ids == tokenizer.im_start_id end_cond = input_ids == tokenizer.im_end_id @@ -61,7 +67,9 @@ def build_image_bound(input_ids, tokenizer, new_schema=True, logger=None): logger.error("image start token != image end tokens") raise Exception("image start token != image end tokens") if len(image_start_tokens) > 0: - image_bound = torch.hstack([image_start_tokens.unsqueeze(-1), image_end_tokens.unsqueeze(-1)]) + image_bound = torch.hstack( + [image_start_tokens.unsqueeze(-1), image_end_tokens.unsqueeze(-1)] + ) else: image_bound = [] return image_bound @@ -92,7 +100,9 @@ def preprocess( assert "patch_size" in slice_config assert "max_slice_nums" in slice_config assert "scale_resolution" in slice_config - default_image_placeholder = tokenizer.im_start + tokenizer.unk_token * query_nums + tokenizer.im_end + default_image_placeholder = ( + tokenizer.im_start + tokenizer.unk_token * query_nums + tokenizer.im_end + ) new_schema = False use_image_id = False if llm_type == "qwen": @@ -117,15 +127,21 @@ def preprocess( images.append(patches[i][j]) if use_image_id: image_placeholder = ( - f"{tokenizer.im_id_start}{image_id_cnt}{tokenizer.im_id_end}" + image_placeholder + f"{tokenizer.im_id_start}{image_id_cnt}{tokenizer.im_id_end}" + + image_placeholder ) image_id_cnt += 1 - image_placeholder += get_grid_placeholder(tokenizer, best_grid, query_nums, new_schema=new_schema) + image_placeholder += get_grid_placeholder( + tokenizer, best_grid, query_nums, new_schema=new_schema + ) image_placeholder_dict[img_name] = image_placeholder else: images.append(image) if use_image_id: - image_placeholder = 
f"{tokenizer.im_id_start}{image_id_cnt}{tokenizer.im_id_end}" + image_placeholder + image_placeholder = ( + f"{tokenizer.im_id_start}{image_id_cnt}{tokenizer.im_id_end}" + + image_placeholder + ) image_id_cnt += 1 else: image_placeholder = default_image_placeholder @@ -135,9 +151,13 @@ def preprocess( if len(images_dict) == 1 and "" in images_dict: if "" in conversations[0]["content"]: - conversations[0]["content"] = conversations[0]["content"].replace("", image_placeholder) + conversations[0]["content"] = conversations[0]["content"].replace( + "", image_placeholder + ) else: - conversations[0]["content"] = image_placeholder + "\n" + conversations[0]["content"] + conversations[0]["content"] = ( + image_placeholder + "\n" + conversations[0]["content"] + ) else: pattern = r"" new_conversations = [] @@ -157,7 +177,9 @@ def preprocess( conversations = new_conversations # TODO change role in conversation for different llm - prompt_with_chat_template = tokenizer.apply_chat_template(conversations, add_generation_prompt=True, tokenize=False) + prompt_with_chat_template = tokenizer.apply_chat_template( + conversations, add_generation_prompt=True, tokenize=False + ) input_ids, attention_mask = verl_F.tokenize_and_postprocess_data( prompt=prompt_with_chat_template, @@ -198,7 +220,9 @@ def preprocess( return input_dict -def slice_image(image, max_slice_nums=9, scale_resolution=448, patch_size=14, never_split=False): +def slice_image( + image, max_slice_nums=9, scale_resolution=448, patch_size=14, never_split=False +): original_size = image.size original_width, original_height = original_size log_ratio = math.log(original_width / original_height) @@ -211,7 +235,9 @@ def slice_image(image, max_slice_nums=9, scale_resolution=448, patch_size=14, ne if multiple <= 1 or never_split: # dont need to slice, upsample - best_size = find_best_resize(original_size, scale_resolution, patch_size, allow_upscale=True) + best_size = find_best_resize( + original_size, scale_resolution, 
patch_size, allow_upscale=True + ) source_image = image.resize(best_size, Image.Resampling.BICUBIC) else: candidate_split_grids_nums = [] @@ -241,7 +267,9 @@ def slice_image(image, max_slice_nums=9, scale_resolution=448, patch_size=14, ne best_grid = grid min_error = error - refine_size = get_refine_size(original_size, best_grid, scale_resolution, patch_size, allow_upscale=True) + refine_size = get_refine_size( + original_size, best_grid, scale_resolution, patch_size, allow_upscale=True + ) refine_image = image.resize(refine_size, Image.Resampling.BICUBIC) patches = split_to_patches(refine_image, best_grid) @@ -264,7 +292,9 @@ def find_best_resize(original_size, scale_resolution, patch_size, allow_upscale= return (best_width, best_height) -def get_refine_size(original_size, grid, scale_resolution, patch_size, allow_upscale=False): +def get_refine_size( + original_size, grid, scale_resolution, patch_size, allow_upscale=False +): width, height = original_size grid_x, grid_y = grid @@ -305,9 +335,15 @@ def split_to_patches(image, grid): def get_grid_placeholder(tokenizer, grid, query_num, new_schema=False): if new_schema: - image_placeholder = tokenizer.slice_start + tokenizer.unk_token * query_num + tokenizer.slice_end + image_placeholder = ( + tokenizer.slice_start + + tokenizer.unk_token * query_num + + tokenizer.slice_end + ) else: - image_placeholder = tokenizer.im_start + tokenizer.unk_token * query_num + tokenizer.im_end + image_placeholder = ( + tokenizer.im_start + tokenizer.unk_token * query_num + tokenizer.im_end + ) cols = grid[0] rows = grid[1] @@ -320,7 +356,9 @@ def get_grid_placeholder(tokenizer, grid, query_num, new_schema=False): if new_schema: slice_placeholder = "\n".join(slices) else: - slice_placeholder = tokenizer.slice_start + "\n".join(slices) + tokenizer.slice_end + slice_placeholder = ( + tokenizer.slice_start + "\n".join(slices) + tokenizer.slice_end + ) return slice_placeholder @@ -330,7 +368,9 @@ def reshape_by_patch(image_tensor, 
patch_size): :param patch_size: :return: [3, patch_size, HW/patch_size] """ - patches = torch.nn.functional.unfold(image_tensor, (patch_size, patch_size), stride=(patch_size, patch_size)) + patches = torch.nn.functional.unfold( + image_tensor, (patch_size, patch_size), stride=(patch_size, patch_size) + ) patches = patches.reshape(image_tensor.size(0), patch_size, patch_size, -1) patches = patches.permute(0, 1, 3, 2).reshape(image_tensor.size(0), patch_size, -1) @@ -344,7 +384,12 @@ def init_minicpmo_config(processor, config): "patch_size": config.get("patch_size", 14), "query_nums": config.get("query_nums", 64), "slice_config": config.get( - "slice_config", {"max_slice_nums": 9, "patch_size": config.get("patch_size", 14), "scale_resolution": 448} + "slice_config", + { + "max_slice_nums": 9, + "patch_size": config.get("patch_size", 14), + "scale_resolution": 448, + }, ), "llm_type": config.get("llm_type", "qwen"), "batch_vision": config.get("batch_vision", True), @@ -353,7 +398,14 @@ def init_minicpmo_config(processor, config): def process_minicpmo_data( - row_dict, messages, tokenizer, minicpmo_config, image_key, max_prompt_length, truncation, logger + row_dict, + messages, + tokenizer, + minicpmo_config, + image_key, + max_prompt_length, + truncation, + logger, ): """Process data for MiniCPM-o model""" if len(row_dict[image_key]) == 1: @@ -379,7 +431,9 @@ def process_minicpmo_data( logger=logger, ) - raw_prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False) + raw_prompt = tokenizer.apply_chat_template( + messages, add_generation_prompt=True, tokenize=False + ) raw_prompt = raw_prompt.replace("", "(./)") return model_inputs, multi_modal_data, raw_prompt @@ -418,7 +472,9 @@ def __init__( self.processor = processor self.config = config - self.cache_dir = os.path.expanduser(config.get("cache_dir", "~/.cache/verl/rlhf")) + self.cache_dir = os.path.expanduser( + config.get("cache_dir", "~/.cache/verl/rlhf") + ) self.prompt_key = 
config.get("prompt_key", "prompt") self.image_key = config.get("image_key", "images") self.video_key = config.get("video_key", "videos") @@ -428,7 +484,9 @@ def __init__( self.truncation = config.get("truncation", "error") self.filter_overlong_prompts = config.get("filter_overlong_prompts", True) - self.num_workers = config.get("filter_overlong_prompts_workers", max(1, os.cpu_count() // 4)) + self.num_workers = config.get( + "filter_overlong_prompts_workers", max(1, os.cpu_count() // 4) + ) self.num_workers = min(self.num_workers, os.cpu_count()) self.use_shm = config.get("use_shm", False) self.chat_template_func = config.get("chat_template_func", None) @@ -442,15 +500,21 @@ def __init__( def _download(self, use_origin_parquet=False): from verl.utils.fs import copy_to_local - data_files = self.data_files if not use_origin_parquet else self.original_data_files + data_files = ( + self.data_files if not use_origin_parquet else self.original_data_files + ) for i, parquet_file in enumerate(data_files): - self.data_files[i] = copy_to_local(src=parquet_file, cache_dir=self.cache_dir, use_shm=self.use_shm) + self.data_files[i] = copy_to_local( + src=parquet_file, cache_dir=self.cache_dir, use_shm=self.use_shm + ) def _read_files_and_tokenize(self): dataframes = [] for parquet_file in self.data_files: # read parquet files and cache - dataframe = datasets.load_dataset("parquet", data_files=parquet_file)["train"] + dataframe = datasets.load_dataset("parquet", data_files=parquet_file)[ + "train" + ] dataframes.append(dataframe) self.dataframe: datasets.Dataset = datasets.concatenate_datasets(dataframes) @@ -460,10 +524,14 @@ def resume_dataset_state(self): self.serialize_dataset = not hasattr(self, "original_data_files") # resume dataframe if not it's serialized in data.pt if not self.serialize_dataset: - self._download(use_origin_parquet=True) # download and resume from original parquet files + self._download( + use_origin_parquet=True + ) # download and resume from original 
parquet files self._read_files_and_tokenize() else: - print(r"old dataloader ckpt file is used, please train from scratch for better ckpt performance") + print( + r"old dataloader ckpt file is used, please train from scratch for better ckpt performance" + ) def __len__(self): return len(self.dataframe) @@ -498,8 +566,12 @@ def __getitem__(self, item): row_dict["multi_modal_data"] = multi_modal_data row_dict["multi_modal_inputs"] = dict(model_inputs) else: - raw_prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False) - model_inputs = self.tokenizer(raw_prompt, return_tensors="pt", add_special_tokens=False) + raw_prompt = self.tokenizer.apply_chat_template( + messages, add_generation_prompt=True, tokenize=False + ) + model_inputs = self.tokenizer( + raw_prompt, return_tensors="pt", add_special_tokens=False + ) input_ids = model_inputs.pop("input_ids") attention_mask = model_inputs.pop("attention_mask") position_ids = compute_position_id_with_mask(attention_mask) @@ -517,9 +589,13 @@ def __getitem__(self, item): elif self.truncation == "middle": left_half = self.max_prompt_length // 2 right_half = self.max_prompt_length - left_half - raw_prompt_ids = raw_prompt_ids[:left_half] + raw_prompt_ids[-right_half:] + raw_prompt_ids = ( + raw_prompt_ids[:left_half] + raw_prompt_ids[-right_half:] + ) elif self.truncation == "error": - raise RuntimeError(f"Prompt length {len(raw_prompt_ids)} is longer than {self.max_prompt_length}.") + raise RuntimeError( + f"Prompt length {len(raw_prompt_ids)} is longer than {self.max_prompt_length}." 
+ ) row_dict["raw_prompt_ids"] = raw_prompt_ids # encode prompts without chat template @@ -533,10 +609,18 @@ def __getitem__(self, item): # add index for each prompt index = row_dict.get("extra_info", {}).get("index", 0) tools_kwargs = row_dict.get("extra_info", {}).get("tools_kwargs", {}) - interaction_kwargs = row_dict.get("extra_info", {}).get("interaction_kwargs", {}) - need_tools_kwargs = row_dict.get("extra_info", {}).get("need_tools_kwargs", self.need_tools_kwargs) + interaction_kwargs = row_dict.get("extra_info", {}).get( + "interaction_kwargs", {} + ) + need_tools_kwargs = row_dict.get("extra_info", {}).get( + "need_tools_kwargs", self.need_tools_kwargs + ) if need_tools_kwargs and not tools_kwargs: - logger.warning("tools_kwargs is empty for index {}, data source: {}", index, row_dict["data_source"]) + logger.warning( + "tools_kwargs is empty for index {}, data source: {}", + index, + row_dict["data_source"], + ) row_dict["index"] = index row_dict["tools_kwargs"] = tools_kwargs row_dict["interaction_kwargs"] = interaction_kwargs diff --git a/Agent0/executor_train/verl/recipe/prime/main_prime.py b/Agent0/executor_train/verl/recipe/prime/main_prime.py index 6bf7f5e..caca917 100644 --- a/Agent0/executor_train/verl/recipe/prime/main_prime.py +++ b/Agent0/executor_train/verl/recipe/prime/main_prime.py @@ -44,7 +44,9 @@ def run_prime(config, compute_score=None): if not ray.is_initialized(): # this is for local ray cluster ray.init( - runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}}, + runtime_env={ + "env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"} + }, num_cpus=config.ray_init.num_cpus, ) @@ -60,7 +62,9 @@ def main_task(config, compute_score=None): from verl.utils.fs import copy_local_path_from_hdfs - pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values + pprint( + OmegaConf.to_container(config, resolve=True) + ) # resolve=True will eval symbol values 
OmegaConf.resolve(config) # download the checkpoint from hdfs @@ -125,12 +129,18 @@ def main_task(config, compute_score=None): reward_manager_cls = PrimeRewardManager else: raise NotImplementedError - reward_fn = reward_manager_cls(tokenizer=tokenizer, num_examine=0, compute_score=compute_score) + reward_fn = reward_manager_cls( + tokenizer=tokenizer, num_examine=0, compute_score=compute_score + ) # Note that we always use function-based RM for validation - val_reward_fn = reward_manager_cls(tokenizer=tokenizer, num_examine=1, compute_score=compute_score) + val_reward_fn = reward_manager_cls( + tokenizer=tokenizer, num_examine=1, compute_score=compute_score + ) - resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) + resource_pool_manager = ResourcePoolManager( + resource_pool_spec=resource_pool_spec, mapping=mapping + ) trainer = RayPRIMETrainer( config=config, diff --git a/Agent0/executor_train/verl/recipe/prime/prime_core_algos.py b/Agent0/executor_train/verl/recipe/prime/prime_core_algos.py index 8256712..b5d6d66 100644 --- a/Agent0/executor_train/verl/recipe/prime/prime_core_algos.py +++ b/Agent0/executor_train/verl/recipe/prime/prime_core_algos.py @@ -18,7 +18,9 @@ import verl.utils.torch_functional as verl_F -def compute_rloo_advantage_return(data: verl.DataProto, response_mask: torch.Tensor, n_samples, config): +def compute_rloo_advantage_return( + data: verl.DataProto, response_mask: torch.Tensor, n_samples, config +): # calculate rloo reward on different reward sources, and sum again def masked_rloo(reward_tensor_original, mask_tensor): reward_tensor = reward_tensor_original.clone() @@ -26,15 +28,21 @@ def masked_rloo(reward_tensor_original, mask_tensor): for start_pos in range(0, reward_tensor.shape[0], n_samples): cur_rewards_mean = torch.cat( [ - reward_tensor[pos : pos + 1][mask_tensor[pos : pos + 1]].mean(dim=0, keepdim=True) + reward_tensor[pos : pos + 1][mask_tensor[pos : pos + 1]].mean( + dim=0, 
keepdim=True + ) for pos in range(start_pos, start_pos + n_samples) ], dim=0, ) cur_rewards_sum = cur_rewards_mean.sum() cur_reward_baseline = cur_rewards_sum / (n_samples - 1) - reward_tensor[start_pos : start_pos + n_samples][mask_tensor[start_pos : start_pos + n_samples]] = ( - reward_tensor[start_pos : start_pos + n_samples][mask_tensor[start_pos : start_pos + n_samples]] + reward_tensor[start_pos : start_pos + n_samples][ + mask_tensor[start_pos : start_pos + n_samples] + ] = ( + reward_tensor[start_pos : start_pos + n_samples][ + mask_tensor[start_pos : start_pos + n_samples] + ] * (n_samples / (n_samples - 1)) - cur_reward_baseline ) @@ -48,7 +56,10 @@ def masked_rloo(reward_tensor_original, mask_tensor): reward_tensor = data.batch["rm_scores"] reward_mask = response_mask.bool() - reward_tensors.append(masked_rloo(reward_tensor, reward_mask) * config.algorithm.reward_dpo_coef) + reward_tensors.append( + masked_rloo(reward_tensor, reward_mask) + * config.algorithm.reward_dpo_coef + ) if "acc" in data.batch.keys() and config.algorithm.reward_gt_coef != 0.0: reward_tensor = torch.zeros_like(response_mask, dtype=torch.float32) @@ -56,22 +67,42 @@ def masked_rloo(reward_tensor_original, mask_tensor): prompt_ids = data.batch["prompts"] prompt_length = prompt_ids.shape[-1] - valid_response_length = data.batch["attention_mask"][:, prompt_length:].sum(-1) + valid_response_length = data.batch["attention_mask"][:, prompt_length:].sum( + -1 + ) reward_mask[ - torch.arange(0, valid_response_length.shape[0], dtype=torch.long, device=valid_response_length.device), + torch.arange( + 0, + valid_response_length.shape[0], + dtype=torch.long, + device=valid_response_length.device, + ), valid_response_length - 1, ] = True reward_tensor[ - torch.arange(0, valid_response_length.shape[0], dtype=torch.long, device=valid_response_length.device), + torch.arange( + 0, + valid_response_length.shape[0], + dtype=torch.long, + device=valid_response_length.device, + ), valid_response_length 
- 1, ] = data.batch["acc"] - reward_tensors.append(masked_rloo(reward_tensor, reward_mask) * config.algorithm.reward_gt_coef) + reward_tensors.append( + masked_rloo(reward_tensor, reward_mask) + * config.algorithm.reward_gt_coef + ) final_reward_tensor = sum(reward_tensors) - returns = (final_reward_tensor * response_mask).flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1]) + returns = ( + (final_reward_tensor * response_mask) + .flip(dims=[-1]) + .cumsum(dim=-1) + .flip(dims=[-1]) + ) advantages = returns.clone() advantages = verl_F.masked_whiten(advantages, response_mask) @@ -85,19 +116,25 @@ def compute_ce_dpo_loss_rm(token_level_scores, acc, response_mask, beta): return cur_dpo_loss -def compute_detach_dpo_loss_rm(token_level_scores, acc, Q_bc, acc_bc, response_mask, beta, bon_mode="none"): +def compute_detach_dpo_loss_rm( + token_level_scores, acc, Q_bc, acc_bc, response_mask, beta, bon_mode="none" +): # we always assume that the BoN size equals n_samples # mode1: use acc as rm # mode2: use Q as rm cur_Q = (token_level_scores * response_mask).sum(dim=1) * beta other_Q = torch.zeros_like(cur_Q) for i in range(token_level_scores.shape[0]): - Q_chosen = Q_bc[i][acc_bc[i] < acc[i]] if acc[i] > 0 else Q_bc[i][acc_bc[i] > acc[i]] + Q_chosen = ( + Q_bc[i][acc_bc[i] < acc[i]] if acc[i] > 0 else Q_bc[i][acc_bc[i] > acc[i]] + ) if len(Q_chosen) > 0: other_Q[i] = Q_chosen.mean() * beta else: other_Q[i] = 0 - dpo_loss = -torch.log(torch.sigmoid((cur_Q - other_Q) * ((acc > 0).float() * 2 - 1))) + dpo_loss = -torch.log( + torch.sigmoid((cur_Q - other_Q) * ((acc > 0).float() * 2 - 1)) + ) if bon_mode == "none": dpo_loss = dpo_loss.mean() else: @@ -105,10 +142,14 @@ def compute_detach_dpo_loss_rm(token_level_scores, acc, Q_bc, acc_bc, response_m n_samples = acc_bc.shape[1] if bon_mode == "bon_rm": for i in range(token_level_scores.shape[0]): - weight[i] = n_samples * torch.pow((Q_bc[i] * beta <= cur_Q[i]).float().mean(), n_samples - 1) + weight[i] = n_samples * torch.pow( + 
(Q_bc[i] * beta <= cur_Q[i]).float().mean(), n_samples - 1 + ) elif bon_mode == "bon_acc": for i in range(token_level_scores.shape[0]): - weight[i] = n_samples * torch.pow((acc_bc[i] <= acc[i]).float().mean(), n_samples - 1) + weight[i] = n_samples * torch.pow( + (acc_bc[i] <= acc[i]).float().mean(), n_samples - 1 + ) else: raise NotImplementedError dpo_loss = (dpo_loss * weight).sum() @@ -120,22 +161,28 @@ def compute_dpo_accuracy(token_level_scores, acc, response_mask, n_samples): dpo_acc = [] for start_id in range(0, token_level_scores.shape[0], n_samples): cur_scores = ( - token_level_scores[start_id : start_id + n_samples] * response_mask[start_id : start_id + n_samples] + token_level_scores[start_id : start_id + n_samples] + * response_mask[start_id : start_id + n_samples] ).sum(dim=1) def get_upper_triangle(tensor_x): diff_matrix = tensor_x.unsqueeze(1) - tensor_x.unsqueeze(0) - upper_tri_indices = torch.triu(torch.ones_like(diff_matrix).bool(), diagonal=1) + upper_tri_indices = torch.triu( + torch.ones_like(diff_matrix).bool(), diagonal=1 + ) return diff_matrix[upper_tri_indices] - cur_acc_diff = get_upper_triangle(acc[start_id : start_id + n_samples]) # in range [-1,1] + cur_acc_diff = get_upper_triangle( + acc[start_id : start_id + n_samples] + ) # in range [-1,1] cur_score_diff = get_upper_triangle(cur_scores) # in R cur_score_prediction = (cur_score_diff > 0).float() # in [0,1] if cur_acc_diff.abs().sum() == 0: cur_acc = torch.zeros_like(cur_score_prediction[0]) + 0.5 else: cur_acc = ( - ((cur_score_diff > 0) == (cur_acc_diff > 0)).float() * cur_acc_diff.abs() + ((cur_score_diff > 0) == (cur_acc_diff > 0)).float() + * cur_acc_diff.abs() ).sum() / cur_acc_diff.abs().sum() dpo_acc.append(cur_acc.unsqueeze(0)) @@ -144,4 +191,11 @@ def get_upper_triangle(tensor_x): def compute_dpo_abs_accuracy(token_level_scores, acc, response_mask, n_samples): - return (torch.sign((token_level_scores * response_mask).sum(dim=-1)) == torch.sign(acc * 2 - 1)).float().mean() 
+ return ( + ( + torch.sign((token_level_scores * response_mask).sum(dim=-1)) + == torch.sign(acc * 2 - 1) + ) + .float() + .mean() + ) diff --git a/Agent0/executor_train/verl/recipe/prime/prime_dp_rm.py b/Agent0/executor_train/verl/recipe/prime/prime_dp_rm.py index c9cc060..4441b21 100644 --- a/Agent0/executor_train/verl/recipe/prime/prime_dp_rm.py +++ b/Agent0/executor_train/verl/recipe/prime/prime_dp_rm.py @@ -36,7 +36,13 @@ class DataParallelPRIMERewardModel: - def __init__(self, config, reward_module: nn.Module, ref_module: nn.Module, reward_optimizer: optim.Optimizer): + def __init__( + self, + config, + reward_module: nn.Module, + ref_module: nn.Module, + reward_optimizer: optim.Optimizer, + ): self.config = config self.reward_module = reward_module self.ref_module = ref_module @@ -46,7 +52,9 @@ def __init__(self, config, reward_module: nn.Module, ref_module: nn.Module, rewa self.use_fused_kernels = self.config.model.get("use_fused_kernels", False) print(f"Reward model use_fused_kernels={self.use_fused_kernels}") - self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1) + self.ulysses_sequence_parallel_size = self.config.get( + "ulysses_sequence_parallel_size", 1 + ) def _forward_micro_batch(self, micro_batch, prompt_length): input_ids = micro_batch["input_ids"] @@ -69,12 +77,18 @@ def _forward_micro_batch(self, micro_batch, prompt_length): ).transpose(0, 1) # for compute the log_prob - input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1) # (1, total_nnz) + input_ids_rmpad_rolled = torch.roll( + input_ids_rmpad, shifts=-1, dims=1 + ) # (1, total_nnz) # pad and slice the inputs if sp > 1 if self.ulysses_sequence_parallel_size > 1: - input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs( - input_ids_rmpad, position_ids_rmpad, sp_size=self.ulysses_sequence_parallel_size + input_ids_rmpad, position_ids_rmpad, pad_size = ( + ulysses_pad_and_slice_inputs( + input_ids_rmpad, + 
position_ids_rmpad, + sp_size=self.ulysses_sequence_parallel_size, + ) ) input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs( input_ids_rmpad_rolled, None, self.ulysses_sequence_parallel_size @@ -101,9 +115,14 @@ def _forward_micro_batch(self, micro_batch, prompt_length): ) if self.ulysses_sequence_parallel_size > 1: - rm_log_labels = gather_outpus_and_unpad(rm_log_labels, gather_dim=0, unpad_dim=0, padding_size=pad_size) + rm_log_labels = gather_outpus_and_unpad( + rm_log_labels, gather_dim=0, unpad_dim=0, padding_size=pad_size + ) rm_log_labels = pad_input( - hidden_states=rm_log_labels.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen + hidden_states=rm_log_labels.unsqueeze(-1), + indices=indices, + batch=batch_size, + seqlen=seqlen, ).squeeze(-1)[:, -num_actions - 1 : -1] else: @@ -124,13 +143,17 @@ def _forward_micro_batch(self, micro_batch, prompt_length): rm_log_prob = torch.nn.functional.log_softmax( rm_output_logits[:, :-1, :], dim=-1 ) # (batch_size, seq_length, vocab_size) - rm_log_labels = rm_log_prob.gather(dim=-1, index=micro_batch["input_ids"][:, 1:].unsqueeze(-1)).squeeze( + rm_log_labels = rm_log_prob.gather( + dim=-1, index=micro_batch["input_ids"][:, 1:].unsqueeze(-1) + ).squeeze( -1 ) # (batch, seq_length) if self.ref_module is not None: # do not have to pad again - with torch.no_grad(), torch.autocast(device_type=get_device_name(), dtype=torch.bfloat16): + with torch.no_grad(), torch.autocast( + device_type=get_device_name(), dtype=torch.bfloat16 + ): if self.ulysses_sequence_parallel_size > 1 and self.use_remove_padding: ref_output = self.ref_module( input_ids=input_ids_rmpad, @@ -153,7 +176,10 @@ def _forward_micro_batch(self, micro_batch, prompt_length): ref_log_labels, gather_dim=0, unpad_dim=0, padding_size=pad_size ) ref_log_labels = pad_input( - hidden_states=ref_log_labels.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen + hidden_states=ref_log_labels.unsqueeze(-1), + indices=indices, + 
batch=batch_size, + seqlen=seqlen, ).squeeze(-1)[:, -num_actions - 1 : -1] else: ref_output = self.ref_module( @@ -164,7 +190,9 @@ def _forward_micro_batch(self, micro_batch, prompt_length): ) if self.use_fused_kernels: - ref_log_labels = ref_output.log_probs[:, :-1] # (batch_size, seq_length) + ref_log_labels = ref_output.log_probs[ + :, :-1 + ] # (batch_size, seq_length) ref_log_labels = ref_log_labels.to(torch.float32) else: @@ -174,13 +202,17 @@ def _forward_micro_batch(self, micro_batch, prompt_length): ) # (batch_size, seq_length, vocab_size) ref_log_labels = ref_log_prob.gather( dim=-1, index=micro_batch["input_ids"][:, 1:].unsqueeze(-1) - ).squeeze(-1) # (batch, seq_length) + ).squeeze( + -1 + ) # (batch, seq_length) else: ref_log_labels = micro_batch["old_log_probs"] ref_log_labels.to(rm_log_labels.dtype) - q = rm_log_labels[:, -num_actions:] - ref_log_labels[:, -num_actions:] # this is actually diff of q + q = ( + rm_log_labels[:, -num_actions:] - ref_log_labels[:, -num_actions:] + ) # this is actually diff of q # trim unnecessary logprobs here for i in range(micro_batch["input_ids"].shape[0]): @@ -204,7 +236,9 @@ def _forward_micro_batch(self, micro_batch, prompt_length): # outcome reward to calculate V for i in range(q.shape[0]): if self.config.prime_use_gt: - q_[i, max_positions[i] - 1] = acc[i] - q_[i, : max_positions[i] - 1].sum() + q_[i, max_positions[i] - 1] = ( + acc[i] - q_[i, : max_positions[i] - 1].sum() + ) q_[i, max_positions[i] :] = 0 for t in reversed(range(num_actions)): @@ -216,10 +250,14 @@ def _forward_micro_batch(self, micro_batch, prompt_length): if self.config.prime_granularity == "token": for i in range(micro_batch["input_ids"].shape[0]): - token_level_score[i, : max_positions[i] - 1] = r[i, : max_positions[i] - 1] + token_level_score[i, : max_positions[i] - 1] = r[ + i, : max_positions[i] - 1 + ] elif self.config.prime_granularity == "whole": for i in range(micro_batch["input_ids"].shape[0]): - token_level_score[i, max_positions[i] 
- 1] = r[i, : max_positions[i]] + token_level_score[i, max_positions[i] - 1] = r[ + i, : max_positions[i] + ] else: raise NotImplementedError @@ -229,33 +267,52 @@ def _optimizer_step(self): assert self.config.model.optim.grad_clip is not None if isinstance(self.reward_module, FSDP): - grad_norm = self.reward_module.clip_grad_norm_(self.config.model.optim.grad_clip) + grad_norm = self.reward_module.clip_grad_norm_( + self.config.model.optim.grad_clip + ) else: grad_norm = torch.nn.utils.clip_grad_norm_( - self.reward_module.parameters(), max_norm=self.config.model.optim.grad_clip + self.reward_module.parameters(), + max_norm=self.config.model.optim.grad_clip, ) self.reward_optimizer.step() return grad_norm def prime_norm(self, token_level_scores): if self.config.prime_norm == "batch_norm": - reverse_cumsum = torch.cumsum(token_level_scores.flip(dims=[1]), dim=-1).flip(dims=[1]) - token_level_scores = token_level_scores / (reverse_cumsum.abs().max() + 1e-6) + reverse_cumsum = torch.cumsum( + token_level_scores.flip(dims=[1]), dim=-1 + ).flip(dims=[1]) + token_level_scores = token_level_scores / ( + reverse_cumsum.abs().max() + 1e-6 + ) return token_level_scores def compute_rm_score(self, data: DataProto): self.reward_module.eval() self.ref_module.eval() micro_batch_size = data.meta_info["micro_batch_size"] - select_keys = ["responses", "input_ids", "attention_mask", "position_ids", "acc"] + select_keys = [ + "responses", + "input_ids", + "attention_mask", + "position_ids", + "acc", + ] batch = data.select(batch_keys=select_keys).batch use_dynamic_bsz = data.meta_info["use_dynamic_bsz"] - prompt_length = data.batch["input_ids"].shape[-1] - data.batch["responses"].shape[-1] + prompt_length = ( + data.batch["input_ids"].shape[-1] - data.batch["responses"].shape[-1] + ) if use_dynamic_bsz: # split using dynamic bsz - max_token_len = data.meta_info["max_token_len"] * self.ulysses_sequence_parallel_size - micro_batches, indices = rearrange_micro_batches(batch=batch, 
max_token_len=max_token_len) + max_token_len = ( + data.meta_info["max_token_len"] * self.ulysses_sequence_parallel_size + ) + micro_batches, indices = rearrange_micro_batches( + batch=batch, max_token_len=max_token_len + ) else: micro_batches = batch.split(micro_batch_size) @@ -273,7 +330,9 @@ def compute_rm_score(self, data: DataProto): if use_dynamic_bsz: indices = list(itertools.chain.from_iterable(indices)) - assert len(indices) == rm_scores.size(0), f"{len(indices)} vs. {rm_scores.size()}" + assert len(indices) == rm_scores.size( + 0 + ), f"{len(indices)} vs. {rm_scores.size()}" revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long) rm_scores = rm_scores[revert_indices] @@ -293,7 +352,14 @@ def update_rm(self, data: DataProto): beta = self.config.model.get("beta_train", 0.05) - select_keys = ["input_ids", "responses", "attention_mask", "position_ids", "acc", "prompts"] + select_keys = [ + "input_ids", + "responses", + "attention_mask", + "position_ids", + "acc", + "prompts", + ] for key in ["Q_bc", "acc_bc"]: if key in data.batch.keys(): @@ -311,11 +377,18 @@ def update_rm(self, data: DataProto): # split batch into micro_batches mini_batch = data if self.config.use_dynamic_bsz: - max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size - micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len) + max_token_len = ( + self.config.ppo_max_token_len_per_gpu + * self.ulysses_sequence_parallel_size + ) + micro_batches, _ = rearrange_micro_batches( + batch=mini_batch, max_token_len=max_token_len + ) else: micro_batches = mini_batch.split(self.config.micro_batch_size_per_gpu) - self.gradient_accumulation = self.config.mini_batch_size // self.config.micro_batch_size_per_gpu + self.gradient_accumulation = ( + self.config.mini_batch_size // self.config.micro_batch_size_per_gpu + ) self.reward_optimizer.zero_grad() @@ -335,12 +408,19 @@ def update_rm(self, data: DataProto): 
q_lst.append(q.detach()) if self.config.model.loss_type == "ce": - dpo_loss = compute_ce_dpo_loss_rm(q, acc, response_mask=response_mask, beta=beta) + dpo_loss = compute_ce_dpo_loss_rm( + q, acc, response_mask=response_mask, beta=beta + ) elif self.config.model.loss_type == "dpo": # the implementation of dpo is actually detached, which means we have to know the average # value of w/l reward before the update. dpo_loss = compute_detach_dpo_loss_rm( - q, acc, Q_bc=data["Q_bc"], acc_bc=data["acc_bc"], response_mask=response_mask, beta=beta + q, + acc, + Q_bc=data["Q_bc"], + acc_bc=data["acc_bc"], + response_mask=response_mask, + beta=beta, ) elif self.config.model.loss_type == "bon_acc": # change the original distribution of each sample to BoN distribution, then update reward model diff --git a/Agent0/executor_train/verl/recipe/prime/prime_fsdp_workers.py b/Agent0/executor_train/verl/recipe/prime/prime_fsdp_workers.py index e353404..958a92f 100644 --- a/Agent0/executor_train/verl/recipe/prime/prime_fsdp_workers.py +++ b/Agent0/executor_train/verl/recipe/prime/prime_fsdp_workers.py @@ -61,28 +61,43 @@ def __init__(self, config): world_size = torch.distributed.get_world_size() fsdp_size = self.config.model.fsdp_config.fsdp_size - self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=fsdp_size) + self.device_mesh = create_device_mesh( + world_size=world_size, fsdp_size=fsdp_size + ) self.ulysses_device_mesh = None - self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1) + self.ulysses_sequence_parallel_size = self.config.get( + "ulysses_sequence_parallel_size", 1 + ) dp = world_size // self.ulysses_sequence_parallel_size if self.ulysses_sequence_parallel_size > 1: self.ulysses_device_mesh = init_device_mesh( - get_device_name(), mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"] + get_device_name(), + mesh_shape=(dp, self.ulysses_sequence_parallel_size), + mesh_dim_names=["dp", "sp"], ) - 
self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh) + self.ulysses_sharding_manager = FSDPUlyssesShardingManager( + self.ulysses_device_mesh + ) # set FSDP offload params self._is_offload_param = self.config.model.fsdp_config.param_offload self._is_offload_optimizer = self.config.model.fsdp_config.optimizer_offload # normalize config - self.config.mini_batch_size //= torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size + self.config.mini_batch_size //= ( + torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size + ) if self.config.micro_batch_size is not None: - self.config.micro_batch_size //= torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size + self.config.micro_batch_size //= ( + torch.distributed.get_world_size() + // self.ulysses_sequence_parallel_size + ) self.config.micro_batch_size_per_gpu = self.config.micro_batch_size - assert self.config.mini_batch_size % self.config.micro_batch_size_per_gpu == 0 + assert ( + self.config.mini_batch_size % self.config.micro_batch_size_per_gpu == 0 + ) def _build_reward_ref_model_optimizer(self, config): # the following line is necessary @@ -96,11 +111,16 @@ def _build_reward_ref_model_optimizer(self, config): local_path = copy_local_path_from_hdfs(config.model.path) tokenizer_path = copy_local_path_from_hdfs(config.model.tokenizer_path) - self.tokenizer = hf_tokenizer(tokenizer_path, trust_remote_code=config.model.get("trust_remote_code", False)) + self.tokenizer = hf_tokenizer( + tokenizer_path, + trust_remote_code=config.model.get("trust_remote_code", False), + ) from omegaconf import OmegaConf - override_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create())) + override_config = OmegaConf.to_container( + self.config.model.get("override_config", OmegaConf.create()) + ) override_config_kwargs = { "bos_token_id": self.tokenizer.bos_token_id, "eos_token_id": self.tokenizer.eos_token_id, @@ 
-116,10 +136,14 @@ def _build_reward_ref_model_optimizer(self, config): from transformers import AutoConfig, AutoModelForCausalLM trust_remote_code = False - reward_model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code) + reward_model_config = AutoConfig.from_pretrained( + local_path, trust_remote_code=trust_remote_code + ) reward_model_config.num_labels = 1 - init_context = get_init_weight_context_manager(use_meta_tensor=not reward_model_config.tie_word_embeddings) + init_context = get_init_weight_context_manager( + use_meta_tensor=not reward_model_config.tie_word_embeddings + ) with init_context(), warnings.catch_warnings(): warnings.simplefilter("ignore") reward_model_config.classifier_dropout = 0.0 @@ -134,7 +158,9 @@ def _build_reward_ref_model_optimizer(self, config): fused_kernel_options = config.model.get("fused_kernel_options", None) fused_kernels_backend = ( - fused_kernel_options.get("impl_backend", None) if fused_kernel_options is not None else None + fused_kernel_options.get("impl_backend", None) + if fused_kernel_options is not None + else None ) apply_monkey_patch( @@ -149,7 +175,9 @@ def _build_reward_ref_model_optimizer(self, config): reward_module.to(torch_dtype) if config.model.get("enable_gradient_checkpointing", False): - reward_module.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) + reward_module.gradient_checkpointing_enable( + gradient_checkpointing_kwargs={"use_reentrant": False} + ) if self.rank == 0: print_model_size(reward_module) @@ -158,17 +186,29 @@ def _build_reward_ref_model_optimizer(self, config): fsdp_config = self.config.model.fsdp_config mixed_precision_config = fsdp_config.get("mixed_precision", None) if mixed_precision_config is not None: - param_dtype = PrecisionType.to_dtype(mixed_precision_config.get("param_dtype", "bf16")) - reduce_dtype = PrecisionType.to_dtype(mixed_precision_config.get("reduce_dtype", "fp32")) - buffer_dtype = 
PrecisionType.to_dtype(mixed_precision_config.get("buffer_dtype", "fp32")) + param_dtype = PrecisionType.to_dtype( + mixed_precision_config.get("param_dtype", "bf16") + ) + reduce_dtype = PrecisionType.to_dtype( + mixed_precision_config.get("reduce_dtype", "fp32") + ) + buffer_dtype = PrecisionType.to_dtype( + mixed_precision_config.get("buffer_dtype", "fp32") + ) else: param_dtype = torch.bfloat16 reduce_dtype = torch.float32 buffer_dtype = torch.float32 - mixed_precision = MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype) + mixed_precision = MixedPrecision( + param_dtype=param_dtype, + reduce_dtype=reduce_dtype, + buffer_dtype=buffer_dtype, + ) - auto_wrap_policy = get_fsdp_wrap_policy(module=reward_module, config=self.config.model.fsdp_config.wrap_policy) + auto_wrap_policy = get_fsdp_wrap_policy( + module=reward_module, config=self.config.model.fsdp_config.wrap_policy + ) log_gpu_memory_usage("Before reward model FSDP", logger=None) @@ -180,7 +220,9 @@ def _build_reward_ref_model_optimizer(self, config): reward_model_config.classifier_dropout = 0.0 reward_model_config.hidden_dropout = "0" ref_module = AutoModelForCausalLM.from_pretrained( - pretrained_model_name_or_path=copy_local_path_from_hdfs(config.model.ref_path), + pretrained_model_name_or_path=copy_local_path_from_hdfs( + config.model.ref_path + ), torch_dtype=torch_dtype, config=reward_model_config, attn_implementation="flash_attention_2", @@ -230,7 +272,9 @@ def _build_reward_ref_model_optimizer(self, config): total_steps = config.model.optim.get("total_training_steps", 0) num_warmup_steps = int(config.model.optim.get("lr_warmup_steps", -1)) if num_warmup_steps < 0: - num_warmup_steps_ratio = config.model.optim.get("lr_warmup_steps_ratio", 0.0) + num_warmup_steps_ratio = config.model.optim.get( + "lr_warmup_steps_ratio", 0.0 + ) num_warmup_steps = int(num_warmup_steps_ratio * total_steps) print(f"Total steps: {total_steps}, num_warmup_steps: 
{num_warmup_steps}") @@ -250,9 +294,12 @@ def init_model(self): from .prime_dp_rm import DataParallelPRIMERewardModel - self.reward_module, self.ref_module, self.reward_optimizer, self.reward_lr_scheduler = ( - self._build_reward_ref_model_optimizer(config=self.config) - ) + ( + self.reward_module, + self.ref_module, + self.reward_optimizer, + self.reward_lr_scheduler, + ) = self._build_reward_ref_model_optimizer(config=self.config) if self._is_offload_param: offload_fsdp_model_to_cpu(self.reward_module) @@ -295,13 +342,22 @@ def compute_rm_score(self, data: DataProto): response_mask = data.batch["attention_mask"][:, prompt_length:] acc = data.batch["acc"] - dpo_acc = compute_dpo_accuracy(rm_scores, acc, response_mask=response_mask, n_samples=data.meta_info["n"]) - dpo_acc_abs = compute_dpo_abs_accuracy(rm_scores, acc, response_mask, n_samples=data.meta_info["n"]) + dpo_acc = compute_dpo_accuracy( + rm_scores, + acc, + response_mask=response_mask, + n_samples=data.meta_info["n"], + ) + dpo_acc_abs = compute_dpo_abs_accuracy( + rm_scores, acc, response_mask, n_samples=data.meta_info["n"] + ) metrics["reward_model/dpo_acc"] = dpo_acc.detach().item() metrics["reward_model/dpo_acc_abs"] = dpo_acc_abs.detach().item() - output = DataProto.from_dict(tensors={"rm_scores": rm_scores, "q": q}, meta_info={"metrics": metrics}) + output = DataProto.from_dict( + tensors={"rm_scores": rm_scores, "q": q}, meta_info={"metrics": metrics} + ) output = self.ulysses_sharding_manager.postprocess_data(data=output) output = output.to("cpu") @@ -317,7 +373,9 @@ def update_rm(self, data: DataProto): load_fsdp_model_to_gpu(self.ref_module) load_fsdp_model_to_gpu(self.reward_module) if self._is_offload_optimizer: - load_fsdp_optimizer(optimizer=self.reward_optimizer, device_id=get_device_id()) + load_fsdp_optimizer( + optimizer=self.reward_optimizer, device_id=get_device_id() + ) # perform forward computation with self.ulysses_sharding_manager: @@ -334,14 +392,21 @@ def update_rm(self, data: 
DataProto): acc = data.batch["acc"] dpo_acc_before = compute_dpo_accuracy( - rm_scores, acc, response_mask=response_mask, n_samples=data.meta_info["n"] + rm_scores, + acc, + response_mask=response_mask, + n_samples=data.meta_info["n"], + ) + dpo_acc_abs = compute_dpo_abs_accuracy( + rm_scores, acc, response_mask, n_samples=data.meta_info["n"] ) - dpo_acc_abs = compute_dpo_abs_accuracy(rm_scores, acc, response_mask, n_samples=data.meta_info["n"]) metrics["reward_model/dpo_acc_before"] = dpo_acc_before.detach().item() metrics["reward_model/dpo_acc_abs_before"] = dpo_acc_abs.detach().item() - output = DataProto.from_dict(tensors={"rm_scores": rm_scores}, meta_info={"metrics": metrics}) + output = DataProto.from_dict( + tensors={"rm_scores": rm_scores}, meta_info={"metrics": metrics} + ) output = self.ulysses_sharding_manager.postprocess_data(data=output) if self._is_offload_param: @@ -353,14 +418,19 @@ def update_rm(self, data: DataProto): return output @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None): + def save_checkpoint( + self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None + ): import torch if self._is_offload_param: load_fsdp_model_to_gpu(self.reward_module) self.checkpoint_manager.save_checkpoint( - local_path=local_path, hdfs_path=hdfs_path, global_step=global_step, max_ckpt_to_keep=max_ckpt_to_keep + local_path=local_path, + hdfs_path=hdfs_path, + global_step=global_step, + max_ckpt_to_keep=max_ckpt_to_keep, ) torch.distributed.barrier() @@ -374,7 +444,9 @@ def load_checkpoint(self, local_path, del_local_after_load=True): if self._is_offload_param: load_fsdp_model_to_gpu(self.reward_module) - self.checkpoint_manager.load_checkpoint(local_path=local_path, del_local_after_load=del_local_after_load) + self.checkpoint_manager.load_checkpoint( + local_path=local_path, del_local_after_load=del_local_after_load + ) torch.distributed.barrier() if 
self._is_offload_param: diff --git a/Agent0/executor_train/verl/recipe/prime/prime_ray_trainer.py b/Agent0/executor_train/verl/recipe/prime/prime_ray_trainer.py index a5ad964..5be8378 100644 --- a/Agent0/executor_train/verl/recipe/prime/prime_ray_trainer.py +++ b/Agent0/executor_train/verl/recipe/prime/prime_ray_trainer.py @@ -30,7 +30,12 @@ from verl.single_controller.ray import RayWorkerGroup from verl.trainer.ppo.core_algos import agg_loss from verl.trainer.ppo.metric_utils import _compute_response_info -from verl.trainer.ppo.ray_trainer import RayPPOTrainer, ResourcePoolManager, Role, WorkerType +from verl.trainer.ppo.ray_trainer import ( + RayPPOTrainer, + ResourcePoolManager, + Role, + WorkerType, +) from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn from verl.utils.metric import reduce_metrics @@ -95,7 +100,9 @@ def compute_data_metrics(batch, use_critic=True): "critic/values/max": torch.max(valid_values).detach().item(), "critic/values/min": torch.min(valid_values).detach().item(), # vf explained var - "critic/vf_explained_var": (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(), + "critic/vf_explained_var": (1.0 - return_diff_var / (return_var + 1e-5)) + .detach() + .item(), } if use_critic else {} @@ -104,14 +111,20 @@ def compute_data_metrics(batch, use_critic=True): "response_length/mean": torch.mean(response_length).detach().item(), "response_length/max": torch.max(response_length).detach().item(), "response_length/min": torch.min(response_length).detach().item(), - "response_length/clip_ratio": torch.mean(torch.eq(response_length, max_response_length).float()) + "response_length/clip_ratio": torch.mean( + torch.eq(response_length, max_response_length).float() + ) .detach() .item(), # prompt length "prompt_length/mean": torch.mean(prompt_length).detach().item(), "prompt_length/max": torch.max(prompt_length).detach().item(), "prompt_length/min": 
torch.min(prompt_length).detach().item(), - "prompt_length/clip_ratio": torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(), + "prompt_length/clip_ratio": torch.mean( + torch.eq(prompt_length, max_prompt_length).float() + ) + .detach() + .item(), } return metrics @@ -131,13 +144,18 @@ def compute_timing_metrics(batch, timing_raw): num_tokens_of_section = { "gen": num_response_tokens, - **{name: num_overall_tokens for name in ["ref", "values", "adv", "update_critic", "update_actor"]}, + **{ + name: num_overall_tokens + for name in ["ref", "values", "adv", "update_critic", "update_actor"] + }, } return { **{f"timing_s/{name}": value for name, value in timing_raw.items()}, **{ - f"timing_per_token_ms/{name}": timing_raw[name] * 1000 / num_tokens_of_section[name] + f"timing_per_token_ms/{name}": timing_raw[name] + * 1000 + / num_tokens_of_section[name] for name in set(num_tokens_of_section.keys()) & set(timing_raw.keys()) }, } @@ -185,26 +203,34 @@ def _create_dataloader(self, *args, **kwargs): # TODO: we have to make sure the batch size is divisible by the dp size self.train_dataset = RLHFDataset( - data_files=self.config.data.train_files, tokenizer=self.tokenizer, config=self.config.data + data_files=self.config.data.train_files, + tokenizer=self.tokenizer, + config=self.config.data, ) # use sampler for better ckpt resume if self.config.data.shuffle: train_dataloader_generator = torch.Generator() train_dataloader_generator.manual_seed(self.config.data.get("seed", 1)) - sampler = RandomSampler(data_source=self.train_dataset, generator=train_dataloader_generator) + sampler = RandomSampler( + data_source=self.train_dataset, generator=train_dataloader_generator + ) else: sampler = SequentialSampler(data_source=self.train_dataset) self.train_dataloader = DataLoader( dataset=self.train_dataset, - batch_size=int(self.config.data.train_batch_size * self.config.data.oversample_factor), + batch_size=int( + self.config.data.train_batch_size * 
self.config.data.oversample_factor + ), drop_last=True, collate_fn=collate_fn, sampler=sampler, ) self.val_dataset = RLHFDataset( - data_files=self.config.data.val_files, tokenizer=self.tokenizer, config=self.config.data + data_files=self.config.data.val_files, + tokenizer=self.tokenizer, + config=self.config.data, ) self.val_dataloader = DataLoader( dataset=self.val_dataset, @@ -221,7 +247,9 @@ def _create_dataloader(self, *args, **kwargs): print(f"Size of val dataloader: {len(self.val_dataloader)}") # inject total_training_steps to actor/critic optim_config. This is hacky. - total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs + total_training_steps = ( + len(self.train_dataloader) * self.config.trainer.total_epochs + ) if self.config.trainer.total_training_steps is not None: total_training_steps = self.config.trainer.total_training_steps @@ -231,7 +259,9 @@ def _create_dataloader(self, *args, **kwargs): OmegaConf.set_struct(self.config, True) with open_dict(self.config): - self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps + self.config.actor_rollout_ref.actor.optim.total_training_steps = ( + total_training_steps + ) self.config.critic.optim.total_training_steps = total_training_steps def _save_checkpoint(self): @@ -245,7 +275,11 @@ def _save_checkpoint(self): actor_remote_path = ( None if self.config.trainer.default_hdfs_dir is None - else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "actor") + else os.path.join( + self.config.trainer.default_hdfs_dir, + f"global_step_{self.global_steps}", + "actor", + ) ) self.actor_rollout_wg.save_checkpoint( actor_local_path, @@ -258,7 +292,11 @@ def _save_checkpoint(self): reward_remote_path = ( None if self.config.trainer.default_hdfs_dir is None - else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "reward") + else os.path.join( + self.config.trainer.default_hdfs_dir, + 
f"global_step_{self.global_steps}", + "reward", + ) ) self.rm_wg.save_checkpoint( reward_local_path, @@ -287,11 +325,15 @@ def _load_checkpoint(self): if self.config.trainer.default_hdfs_dir is not None: NotImplementedError("load from hdfs is not implemented yet") else: - checkpoint_folder = self.config.trainer.default_local_dir # TODO: check path + checkpoint_folder = ( + self.config.trainer.default_local_dir + ) # TODO: check path if not os.path.isabs(checkpoint_folder): working_dir = os.getcwd() checkpoint_folder = os.path.join(working_dir, checkpoint_folder) - global_step_folder = find_latest_ckpt_path(checkpoint_folder) # None if no latest + global_step_folder = find_latest_ckpt_path( + checkpoint_folder + ) # None if no latest # find global_step_folder if self.config.trainer.resume_mode == "auto": @@ -300,10 +342,12 @@ def _load_checkpoint(self): return 0 else: if self.config.trainer.resume_mode == "resume_path": - assert isinstance(self.config.trainer.resume_from_path, str), "resume ckpt must be str type" - assert "global_step_" in self.config.trainer.resume_from_path, ( - "resume ckpt must specify the global_steps" - ) + assert isinstance( + self.config.trainer.resume_from_path, str + ), "resume ckpt must be str type" + assert ( + "global_step_" in self.config.trainer.resume_from_path + ), "resume ckpt must specify the global_steps" global_step_folder = self.config.trainer.resume_from_path if not os.path.isabs(global_step_folder): working_dir = os.getcwd() @@ -319,11 +363,15 @@ def _load_checkpoint(self): reward_path = os.path.join(global_step_folder, "reward") # load actor self.actor_rollout_wg.load_checkpoint( - actor_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load + actor_path, + del_local_after_load=self.config.trainer.del_local_ckpt_after_load, ) # load rm if self.use_rm: - self.rm_wg.load_checkpoint(reward_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load) + self.rm_wg.load_checkpoint( + reward_path, + 
del_local_after_load=self.config.trainer.del_local_ckpt_after_load, + ) # load dataloader, # TODO: from remote not implemented yet @@ -356,7 +404,9 @@ def fit(self): # perform validation before training # currently, we only support validation using the reward_function. - if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True): + if self.val_reward_fn is not None and self.config.trainer.get( + "val_before_train", True + ): val_metrics = self._validate() assert val_metrics, f"{val_metrics=}" pprint(f"Initial validation metrics: {val_metrics}") @@ -375,13 +425,20 @@ def fit(self): batch: DataProto = DataProto.from_single_dict(batch_dict) # pop those keys for generation - gen_batch = batch.pop(batch_keys=["input_ids", "attention_mask", "position_ids"]) - gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) + gen_batch = batch.pop( + batch_keys=["input_ids", "attention_mask", "position_ids"] + ) + gen_batch = gen_batch.repeat( + repeat_times=self.config.actor_rollout_ref.rollout.n, + interleave=True, + ) with simple_timer("step", timing_raw): # generate a batch with simple_timer("gen", timing_raw): - gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) + gen_batch_output = self.actor_rollout_wg.generate_sequences( + gen_batch + ) timing_raw.update(gen_batch_output.meta_info["timing"]) gen_batch_output.meta_info.pop("timing", None) @@ -389,7 +446,11 @@ def fit(self): with simple_timer("gen_max", timing_raw): gen_baseline_batch = deepcopy(gen_batch) gen_baseline_batch.meta_info["do_sample"] = False - gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch) + gen_baseline_output = ( + self.actor_rollout_wg.generate_sequences( + gen_baseline_batch + ) + ) batch = batch.union(gen_baseline_output) reward_baseline_tensor = self.reward_fn(batch) @@ -402,10 +463,14 @@ def fit(self): del gen_baseline_batch, gen_baseline_output batch.non_tensor_batch["uid"] = 
np.array( - [str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object + [str(uuid.uuid4()) for _ in range(len(batch.batch))], + dtype=object, ) # repeat to align with repeated responses in rollout - batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) + batch = batch.repeat( + repeat_times=self.config.actor_rollout_ref.rollout.n, + interleave=True, + ) batch = batch.union(gen_batch_output) # Balance the number of valid tokens across DP ranks. @@ -417,7 +482,9 @@ def fit(self): self._balance_batch(batch, metrics=metrics) # compute global_valid tokens - batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist() + batch.meta_info["global_token_num"] = torch.sum( + batch.batch["attention_mask"], dim=-1 + ).tolist() # verify with simple_timer("verify", timing_raw): @@ -436,9 +503,17 @@ def fit(self): old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) entropys = old_log_prob.batch["entropys"] response_masks = compute_response_mask(batch) - loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode - entropy_agg = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode) - old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()} + loss_agg_mode = ( + self.config.actor_rollout_ref.actor.loss_agg_mode + ) + entropy_agg = agg_loss( + loss_mat=entropys, + loss_mask=response_masks, + loss_agg_mode=loss_agg_mode, + ) + old_log_prob_metrics = { + "actor/entropy": entropy_agg.detach().item() + } metrics.update(old_log_prob_metrics) old_log_prob.batch.pop("entropys") batch = batch.union(old_log_prob) @@ -446,20 +521,30 @@ def fit(self): if self.use_reference_policy: # compute reference log_prob with simple_timer("ref", timing_raw): - ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) + ref_log_prob = self.ref_policy_wg.compute_ref_log_prob( + batch + ) batch = batch.union(ref_log_prob) with simple_timer("adv", timing_raw): if 
self.use_rm: - update_style = self.config.reward_model.model.get("update", "none") + update_style = self.config.reward_model.model.get( + "update", "none" + ) if update_style == "none": # only run forward reward_output = self.rm_wg.compute_rm_score(batch) - elif update_style == "after": # update and directly return the reward + elif ( + update_style == "after" + ): # update and directly return the reward reward_output = self.rm_wg.update_rm(batch) - elif update_style == "before": # update reward model, and then run forward + elif ( + update_style == "before" + ): # update reward model, and then run forward reward_output = self.rm_wg.update_rm(batch) if "metrics" in reward_output.meta_info.keys(): - reward_output_metrics = reduce_metrics(reward_output.meta_info["metrics"]) + reward_output_metrics = reduce_metrics( + reward_output.meta_info["metrics"] + ) metrics.update(reward_output_metrics) reward_output = self.rm_wg.compute_rm_score(batch) @@ -489,18 +574,24 @@ def fit(self): raise NotImplementedError batch = batch.union(reward_output) if "metrics" in reward_output.meta_info.keys(): - reward_output_metrics = reduce_metrics(reward_output.meta_info["metrics"]) + reward_output_metrics = reduce_metrics( + reward_output.meta_info["metrics"] + ) metrics.update(reward_output_metrics) # compute advantages, executed on the driver process batch = compute_advantage( - batch, adv_estimator=self.config.algorithm.adv_estimator, config=self.config + batch, + adv_estimator=self.config.algorithm.adv_estimator, + config=self.config, ) # update actor with simple_timer("update_actor", timing_raw): actor_output = self.actor_rollout_wg.update_actor(batch) - actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) + actor_output_metrics = reduce_metrics( + actor_output.meta_info["metrics"] + ) metrics.update(actor_output_metrics) # validate @@ -513,13 +604,20 @@ def fit(self): val_metrics: dict = self._validate() metrics.update(val_metrics) - if 
self.config.trainer.save_freq > 0 and self.global_steps % self.config.trainer.save_freq == 0: + if ( + self.config.trainer.save_freq > 0 + and self.global_steps % self.config.trainer.save_freq == 0 + ): with simple_timer("save_checkpoint", timing_raw): self._save_checkpoint() # collect metrics - metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic)) - metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) + metrics.update( + compute_data_metrics(batch=batch, use_critic=self.use_critic) + ) + metrics.update( + compute_timing_metrics(batch=batch, timing_raw=timing_raw) + ) # TODO: make a canonical logger that supports various backend logger.log(data=metrics, step=self.global_steps) @@ -564,10 +662,15 @@ def filter_and_downsample(self, scores, batch: DataProto): .reshape(-1, n_samples) ) length_tensor = torch.max(length_matrix, dim=-1)[0] - filter_mask[length_tensor >= self.config.data.max_response_length - 1] = False + filter_mask[length_tensor >= self.config.data.max_response_length - 1] = ( + False + ) reorder_index = torch.argsort(filter_mask, descending=True) - reorder_index = (reorder_index.unsqueeze(-1) * n_samples + torch.arange(0, n_samples).unsqueeze(0)).view(-1) + reorder_index = ( + reorder_index.unsqueeze(-1) * n_samples + + torch.arange(0, n_samples).unsqueeze(0) + ).view(-1) batch.reorder( reorder_index[: int(len(batch) // self.config.data.oversample_factor)] ) # this operation is inplace diff --git a/Agent0/executor_train/verl/recipe/r1/data_process.py b/Agent0/executor_train/verl/recipe/r1/data_process.py index fb41c81..0b8aa9c 100644 --- a/Agent0/executor_train/verl/recipe/r1/data_process.py +++ b/Agent0/executor_train/verl/recipe/r1/data_process.py @@ -44,9 +44,15 @@ def process_aime2024(example): print(f"Loading the {data_source} dataset from huggingface...", flush=True) dataset = load_dataset(data_source, split="train") map_fn = partial( - example_map_fn, process_fn=process_aime2024, 
data_source=data_source, ability="English", split="test" + example_map_fn, + process_fn=process_aime2024, + data_source=data_source, + ability="English", + split="test", + ) + dataset = dataset.map( + map_fn, with_indices=True, remove_columns=dataset.column_names ) - dataset = dataset.map(map_fn, with_indices=True, remove_columns=dataset.column_names) return dataset @@ -60,12 +66,20 @@ def build_gpqa_dimond_dataset(): ) def process_gpqa_diamond(example): - choices = [example["Incorrect Answer 1"], example["Incorrect Answer 2"], example["Incorrect Answer 3"]] + choices = [ + example["Incorrect Answer 1"], + example["Incorrect Answer 2"], + example["Incorrect Answer 3"], + ] random.shuffle(choices) gold_index = random.randint(0, 3) choices.insert(gold_index, example["Correct Answer"]) query_prompt = GPQA_QUERY_TEMPLATE.format( - A=choices[0], B=choices[1], C=choices[2], D=choices[3], Question=example["Question"] + A=choices[0], + B=choices[1], + C=choices[2], + D=choices[3], + Question=example["Question"], ) gold_choice = "ABCD"[gold_index] return query_prompt, gold_choice @@ -75,9 +89,15 @@ def process_gpqa_diamond(example): dataset = load_dataset(data_source, "gpqa_diamond", split="train") map_fn = partial( - example_map_fn, process_fn=process_gpqa_diamond, data_source=data_source, ability="Math", split="test" + example_map_fn, + process_fn=process_gpqa_diamond, + data_source=data_source, + ability="Math", + split="test", + ) + dataset = dataset.map( + map_fn, with_indices=True, remove_columns=dataset.column_names ) - dataset = dataset.map(map_fn, with_indices=True, remove_columns=dataset.column_names) return dataset @@ -90,15 +110,27 @@ def process_cnmo2024(example): dataset_en = load_dataset(data_source, "v202412_CNMO_en", split="test") map_fn_en = partial( - example_map_fn, process_fn=process_cnmo2024, data_source="opencompass/cnmo2024_en", ability="Math", split="test" + example_map_fn, + process_fn=process_cnmo2024, + data_source="opencompass/cnmo2024_en", + 
ability="Math", + split="test", + ) + dataset_en = dataset_en.map( + map_fn_en, with_indices=True, remove_columns=dataset_en.column_names ) - dataset_en = dataset_en.map(map_fn_en, with_indices=True, remove_columns=dataset_en.column_names) dataset_zh = load_dataset(data_source, "v202412_CNMO_cn", split="test") map_fn_zh = partial( - example_map_fn, process_fn=process_cnmo2024, data_source="opencompass/cnmo2024_zh", ability="Math", split="test" + example_map_fn, + process_fn=process_cnmo2024, + data_source="opencompass/cnmo2024_zh", + ability="Math", + split="test", + ) + dataset_zh = dataset_zh.map( + map_fn_zh, with_indices=True, remove_columns=dataset_zh.column_names ) - dataset_zh = dataset_zh.map(map_fn_zh, with_indices=True, remove_columns=dataset_zh.column_names) dataset = concatenate_datasets([dataset_en, dataset_zh]) return dataset @@ -137,7 +169,11 @@ def process_livecodebench(example): except Exception as e: print(f"Error loading private test cases: {e}") private_test_cases = json.loads( - pickle.loads(zlib.decompress(base64.b64decode(example["private_test_cases"].encode("utf-8")))) + pickle.loads( + zlib.decompress( + base64.b64decode(example["private_test_cases"].encode("utf-8")) + ) + ) ) full_test_cases = public_test_cases + private_test_cases @@ -147,19 +183,31 @@ def process_livecodebench(example): "outputs": [t["output"] for t in full_test_cases], "fn_name": metadata.get("func_name", None), } - text_cases_compressed = base64.b64encode(zlib.compress(pickle.dumps(json.dumps(test_cases)))).decode("utf-8") + text_cases_compressed = base64.b64encode( + zlib.compress(pickle.dumps(json.dumps(test_cases))) + ).decode("utf-8") return query_prompt, text_cases_compressed data_source = "livecodebench/code_generation_lite" print(f"Loading the {data_source} dataset from huggingface...", flush=True) dataset = load_dataset(data_source, split="test") # R1 Evaluation use LiveCodeBench 24.08-25.01 - dataset = dataset.filter(lambda line: "2024-08-00T00:00:00" <= 
line["contest_date"] < "2025-01-00T00:00:00") + dataset = dataset.filter( + lambda line: "2024-08-00T00:00:00" + <= line["contest_date"] + < "2025-01-00T00:00:00" + ) map_fn = partial( - example_map_fn, process_fn=process_livecodebench, data_source=data_source, ability="Code", split="test" + example_map_fn, + process_fn=process_livecodebench, + data_source=data_source, + ability="Code", + split="test", ) - dataset = dataset.map(map_fn, with_indices=True, remove_columns=dataset.column_names, num_proc=8) + dataset = dataset.map( + map_fn, with_indices=True, remove_columns=dataset.column_names, num_proc=8 + ) return dataset diff --git a/Agent0/executor_train/verl/recipe/r1/main_eval.py b/Agent0/executor_train/verl/recipe/r1/main_eval.py index b9c0379..5358654 100644 --- a/Agent0/executor_train/verl/recipe/r1/main_eval.py +++ b/Agent0/executor_train/verl/recipe/r1/main_eval.py @@ -56,7 +56,8 @@ def main(config): # Create remote tasks remote_tasks = [ - process_item.remote(config, data_sources[i], responses[i], reward_model_data[i]) for i in range(total) + process_item.remote(config, data_sources[i], responses[i], reward_model_data[i]) + for i in range(total) ] # Process results as they come in diff --git a/Agent0/executor_train/verl/recipe/r1/reward_score.py b/Agent0/executor_train/verl/recipe/r1/reward_score.py index 2010665..c602021 100644 --- a/Agent0/executor_train/verl/recipe/r1/reward_score.py +++ b/Agent0/executor_train/verl/recipe/r1/reward_score.py @@ -14,7 +14,11 @@ def reward_func(data_source, solution_str, ground_truth, extra_info=None): - if data_source in ["Maxwell-Jia/AIME_2024", "opencompass/cnmo2024_en", "opencompass/cnmo2024_zh"]: + if data_source in [ + "Maxwell-Jia/AIME_2024", + "opencompass/cnmo2024_en", + "opencompass/cnmo2024_zh", + ]: from recipe.r1.tasks import math return math.compute_score(solution_str, ground_truth) @@ -22,7 +26,10 @@ def reward_func(data_source, solution_str, ground_truth, extra_info=None): from recipe.r1.tasks import gpqa 
return gpqa.compute_score(solution_str, ground_truth) - elif data_source in ["livecodebench/code_generation_lite", "livecodebench/code_generation"]: + elif data_source in [ + "livecodebench/code_generation_lite", + "livecodebench/code_generation", + ]: from recipe.r1.tasks import livecodebench return livecodebench.compute_score(solution_str, ground_truth) diff --git a/Agent0/executor_train/verl/recipe/r1/tasks/livecodebench.py b/Agent0/executor_train/verl/recipe/r1/tasks/livecodebench.py index f0cbab6..ac55e59 100644 --- a/Agent0/executor_train/verl/recipe/r1/tasks/livecodebench.py +++ b/Agent0/executor_train/verl/recipe/r1/tasks/livecodebench.py @@ -60,11 +60,15 @@ def compute_score(completion, test_cases): in_outs = json.loads(test_cases) except Exception as e: print(f"Error loading test cases: {e}") - in_outs = json.loads(pickle.loads(zlib.decompress(base64.b64decode(test_cases.encode("utf-8"))))) + in_outs = json.loads( + pickle.loads(zlib.decompress(base64.b64decode(test_cases.encode("utf-8")))) + ) success = False try: - res, metadata = check_correctness(in_outs=in_outs, generation=solution, timeout=6, debug=False) + res, metadata = check_correctness( + in_outs=in_outs, generation=solution, timeout=6, debug=False + ) success = all(map(lambda x: x is True, res)) except Exception: pass diff --git a/Agent0/executor_train/verl/recipe/r1/tasks/math.py b/Agent0/executor_train/verl/recipe/r1/tasks/math.py index 5ecde54..7d632cd 100644 --- a/Agent0/executor_train/verl/recipe/r1/tasks/math.py +++ b/Agent0/executor_train/verl/recipe/r1/tasks/math.py @@ -17,7 +17,9 @@ from math_verify.metric import math_metric from math_verify.parser import ExprExtractionConfig, LatexExtractionConfig except ImportError: - print("To use Math-Verify, please install it first by running `pip install math-verify`.") + print( + "To use Math-Verify, please install it first by running `pip install math-verify`." 
+ ) def compute_score(model_output: str, ground_truth: str) -> bool: diff --git a/Agent0/executor_train/verl/recipe/retool/retool.py b/Agent0/executor_train/verl/recipe/retool/retool.py index b4d6028..0c25825 100644 --- a/Agent0/executor_train/verl/recipe/retool/retool.py +++ b/Agent0/executor_train/verl/recipe/retool/retool.py @@ -32,7 +32,9 @@ def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema): self.code_pattern = re.compile(r"```python(.*?)```", re.DOTALL) @rollout_trace_op - async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]: + async def execute( + self, instance_id: str, parameters: dict[str, Any], **kwargs + ) -> tuple[str, float, dict]: code = parameters["code"] matches = self.code_pattern.findall(code) if matches: @@ -53,12 +55,16 @@ async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) if not isinstance(code, str): code = str(code) - result = await self.execution_pool.execute.remote(self.execute_code, instance_id, code, timeout, language) + result = await self.execution_pool.execute.remote( + self.execute_code, instance_id, code, timeout, language + ) # sandbox has no score or metrics, use Nones return result, None, None -answer_format = """\nThe answer format must be: \\boxed{'The final answer goes here.'}""" +answer_format = ( + """\nThe answer format must be: \\boxed{'The final answer goes here.'}""" +) class CustomRLHFDataset(RLHFDataset): @@ -72,7 +78,9 @@ def _read_files_and_tokenize(self): data_source = "/".join(parquet_file.split("/")[-2:]) if data_source in ["Maxwell-Jia/AIME_2024", "yentinglin/aime_2025"]: dataframe = dataframe.map( - self.map_fn, fn_kwargs={"data_source": data_source}, remove_columns=dataframe.column_names + self.map_fn, + fn_kwargs={"data_source": data_source}, + remove_columns=dataframe.column_names, ) else: dataframe = dataframe.map(self.map_fn2, num_proc=16) diff --git 
a/Agent0/executor_train/verl/recipe/retool/retool_multi_turn_sft_preprocess.py b/Agent0/executor_train/verl/recipe/retool/retool_multi_turn_sft_preprocess.py index 201ee68..15f3a99 100644 --- a/Agent0/executor_train/verl/recipe/retool/retool_multi_turn_sft_preprocess.py +++ b/Agent0/executor_train/verl/recipe/retool/retool_multi_turn_sft_preprocess.py @@ -37,7 +37,9 @@ def main(): shuffled_train_dataset = train_dataset.shuffle(seed=args.seed) split_idx = int(len(shuffled_train_dataset) * args.train_ratio) train_dataset = shuffled_train_dataset.select(range(split_idx)) - test_dataset = shuffled_train_dataset.select(range(split_idx, len(shuffled_train_dataset))) + test_dataset = shuffled_train_dataset.select( + range(split_idx, len(shuffled_train_dataset)) + ) # add a row to each data item that represents a unique id def make_map_fn(split): diff --git a/Agent0/executor_train/verl/recipe/retool/retool_sft_preprocess.py b/Agent0/executor_train/verl/recipe/retool/retool_sft_preprocess.py index 0a46c15..db15593 100644 --- a/Agent0/executor_train/verl/recipe/retool/retool_sft_preprocess.py +++ b/Agent0/executor_train/verl/recipe/retool/retool_sft_preprocess.py @@ -94,7 +94,12 @@ def process(row: dict, *, tools: str): start = "*user question:*" i = content.find(start) assert i != -1 - prompt = content[i + len(start) :].replace("", "").replace("", "").strip() + prompt = ( + content[i + len(start) :] + .replace("", "") + .replace("", "") + .strip() + ) messages.append( { "role": "user", diff --git a/Agent0/executor_train/verl/recipe/spin/core_algos.py b/Agent0/executor_train/verl/recipe/spin/core_algos.py index c48027e..3a7dae1 100644 --- a/Agent0/executor_train/verl/recipe/spin/core_algos.py +++ b/Agent0/executor_train/verl/recipe/spin/core_algos.py @@ -50,8 +50,14 @@ def get_kl_controller(kl_ctrl): if kl_ctrl.type == "fixed": return FixedKLController(kl_coef=kl_ctrl.kl_coef) elif kl_ctrl.type == "adaptive": - assert kl_ctrl.horizon > 0, f"horizon must be larger than 0. 
Got {kl_ctrl.horizon}" - return AdaptiveKLController(init_kl_coef=kl_ctrl.kl_coef, target_kl=kl_ctrl.target_kl, horizon=kl_ctrl.horizon) + assert ( + kl_ctrl.horizon > 0 + ), f"horizon must be larger than 0. Got {kl_ctrl.horizon}" + return AdaptiveKLController( + init_kl_coef=kl_ctrl.kl_coef, + target_kl=kl_ctrl.target_kl, + horizon=kl_ctrl.horizon, + ) else: raise NotImplementedError @@ -83,7 +89,9 @@ def compute_onlinedpo_pref( f"{token_level_rewards.shape}, {response_mask.shape}" ) if token_level_rewards.shape != response_mask.shape: - raise ValueError(f"Shape mismatch between rewards {token_level_rewards.shape} and mask {response_mask.shape}") + raise ValueError( + f"Shape mismatch between rewards {token_level_rewards.shape} and mask {response_mask.shape}" + ) # 1. Calculate Sequence Scores scores = (token_level_rewards * response_mask).sum(dim=-1) @@ -99,7 +107,9 @@ def compute_onlinedpo_pref( # 3. Compare scores to find which index (0 or 1) is the winner within each pair # winner_indices[i] = 0 if score_pairs[i, 0] >= score_pairs[i, 1] else 1 - winner_indices = torch.argmax(score_pairs, dim=1) # 0 if first is max, 1 if second is max + winner_indices = torch.argmax( + score_pairs, dim=1 + ) # 0 if first is max, 1 if second is max # Handle ties explicitly if argmax behavior isn't guaranteed (usually picks first max) # Alternatively: winner_mask_original = score_pairs[:, 0] >= score_pairs[:, 1] # print(f" Winner indices shape: {winner_indices.shape}") # [batch_size] @@ -117,7 +127,9 @@ def compute_onlinedpo_pref( winner_global_indices = (pair_indices * 2) + winner_indices # Create boolean mask - True at the winner's position - output_preference_mask = torch.zeros(full_batch_size, dtype=torch.bool, device=scores.device) + output_preference_mask = torch.zeros( + full_batch_size, dtype=torch.bool, device=scores.device + ) output_preference_mask[winner_global_indices] = True # print(f" Output preference mask shape: {output_preference_mask.shape}") # Should be 
[batch_size * 2] @@ -149,11 +161,16 @@ def compute_online_dpo_loss( logits = pi_logratios - ref_logratios if loss_type == "sigmoid": - losses = -F.logsigmoid(beta * logits) * (1 - label_smoothing) - F.logsigmoid(-beta * logits) * label_smoothing + losses = ( + -F.logsigmoid(beta * logits) * (1 - label_smoothing) + - F.logsigmoid(-beta * logits) * label_smoothing + ) elif loss_type == "ipo": losses = (logits - 1 / (2 * beta)) ** 2 else: - raise ValueError(f"Unsupported loss_type: {loss_type}. Choose 'sigmoid', 'ipo', or 'hinge'.") + raise ValueError( + f"Unsupported loss_type: {loss_type}. Choose 'sigmoid', 'ipo', or 'hinge'." + ) return losses.mean() @@ -184,7 +201,9 @@ def get_batch_logps( # Calculate per token log probability loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction="none") - per_token_logps = -loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + per_token_logps = -loss_fct( + shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1) + ) per_token_logps = per_token_logps.view( shift_logits.size(0), shift_logits.size(1) ) # Reshape back to (batch_size, seq_len-1) diff --git a/Agent0/executor_train/verl/recipe/spin/dp_actor.py b/Agent0/executor_train/verl/recipe/spin/dp_actor.py index 35caa29..143641a 100644 --- a/Agent0/executor_train/verl/recipe/spin/dp_actor.py +++ b/Agent0/executor_train/verl/recipe/spin/dp_actor.py @@ -53,7 +53,9 @@ def compute_log_prob(self, data: DataProto) -> torch.Tensor: self.actor_module.eval() micro_batch_size = data.meta_info["micro_batch_size"] - temperature = data.meta_info["temperature"] # temperature must be in the data.meta_info to avoid silent error + temperature = data.meta_info[ + "temperature" + ] # temperature must be in the data.meta_info to avoid silent error use_dynamic_bsz = data.meta_info["use_dynamic_bsz"] select_keys = ["responses", "input_ids", "attention_mask", "position_ids"] @@ -63,11 +65,17 @@ def compute_log_prob(self, data: DataProto) -> torch.Tensor: 
if has_multi_modal_inputs: num_micro_batches = data.batch.batch_size[0] // micro_batch_size non_tensor_select_keys = ["multi_modal_inputs"] - micro_batches = data.select(select_keys, non_tensor_select_keys).chunk(num_micro_batches) + micro_batches = data.select(select_keys, non_tensor_select_keys).chunk( + num_micro_batches + ) elif use_dynamic_bsz: # split using dynamic bsz - max_token_len = data.meta_info["max_token_len"] * self.ulysses_sequence_parallel_size - micro_batches, indices = rearrange_micro_batches(batch=batch, max_token_len=max_token_len) + max_token_len = ( + data.meta_info["max_token_len"] * self.ulysses_sequence_parallel_size + ) + micro_batches, indices = rearrange_micro_batches( + batch=batch, max_token_len=max_token_len + ) else: micro_batches = batch.split(micro_batch_size) @@ -77,13 +85,17 @@ def compute_log_prob(self, data: DataProto) -> torch.Tensor: micro_batch = {**micro_batch.batch, **micro_batch.non_tensor_batch} with torch.no_grad(): - _, log_probs = self._forward_micro_batch(micro_batch, temperature=temperature) + _, log_probs = self._forward_micro_batch( + micro_batch, temperature=temperature + ) log_probs_lst.append(log_probs) log_probs = torch.concat(log_probs_lst, dim=0) if use_dynamic_bsz: indices = list(itertools.chain.from_iterable(indices)) - assert len(indices) == log_probs.size(0), f"{len(indices)} vs. {log_probs.size()}" + assert len(indices) == log_probs.size( + 0 + ), f"{len(indices)} vs. {log_probs.size()}" revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long) log_probs = log_probs[revert_indices] @@ -105,8 +117,12 @@ def update_policy_dpo_with_ref(self, data: DataProto): # ... other needed tensors like chosen/rejected input_ids, attention_mask, position_ids ... 
# === Get PRE-CALCULATED reference log probs from input data === - reference_chosen_logps = batch_td["reference_chosen_logps"] # Should be sequence-level logps - reference_rejected_logps = batch_td["reference_rejected_logps"] # Should be sequence-level logps + reference_chosen_logps = batch_td[ + "reference_chosen_logps" + ] # Should be sequence-level logps + reference_rejected_logps = batch_td[ + "reference_rejected_logps" + ] # Should be sequence-level logps # ============================================================ # Get DPO params from meta_info @@ -115,14 +131,22 @@ def update_policy_dpo_with_ref(self, data: DataProto): loss_type = data.meta_info.get("dpo_loss_type", "sigmoid") label_smoothing = data.meta_info.get("dpo_label_smoothing", 0.0) # reference_free should now be False as we provide ref logps - reference_free = data.meta_info.get("reference_free", False) # Default False + reference_free = data.meta_info.get( + "reference_free", False + ) # Default False except KeyError as e: - print(f"ERROR: Missing required key for DPO update (in update_policy_dpo): {e}") - print(f"Available keys in data.batch: {list(batch_td.keys())}") # Debug print + print( + f"ERROR: Missing required key for DPO update (in update_policy_dpo): {e}" + ) + print( + f"Available keys in data.batch: {list(batch_td.keys())}" + ) # Debug print return {} # Return empty metrics on error except Exception as e_data: - print(f"ERROR accessing data for DPO update (in update_policy_dpo): {e_data}") + print( + f"ERROR accessing data for DPO update (in update_policy_dpo): {e_data}" + ) return {} # --- Micro-batching Setup --- @@ -130,7 +154,9 @@ def update_policy_dpo_with_ref(self, data: DataProto): if micro_batch_size is None: # Fallback or default if not set, or raise error micro_batch_size = 1 # Example fallback, adjust as needed - print(f"Warning: 'ppo_micro_batch_size_per_gpu' not set, defaulting to {micro_batch_size}") + print( + f"Warning: 'ppo_micro_batch_size_per_gpu' not set, 
defaulting to {micro_batch_size}" + ) # raise ValueError("Config 'ppo_micro_batch_size_per_gpu' must be set.") # Ensure chosen_input_ids exists before getting shape @@ -141,7 +167,10 @@ def update_policy_dpo_with_ref(self, data: DataProto): if bsz == 0: print("Warning: DPO batch size is 0 in update_policy_dpo. Skipping update.") - return {"actor/dpo_loss": 0.0, "actor/grad_norm": 0.0} # Return zero metrics if batch is empty + return { + "actor/dpo_loss": 0.0, + "actor/grad_norm": 0.0, + } # Return zero metrics if batch is empty num_micro_batches = math.ceil(bsz / micro_batch_size) gradient_accumulation_steps = num_micro_batches @@ -170,29 +199,45 @@ def update_policy_dpo_with_ref(self, data: DataProto): "attention_mask": batch_td["chosen_attention_mask"][start_idx:end_idx], } if "chosen_position_ids" in batch_td: - micro_batch_chosen_inputs["position_ids"] = batch_td["chosen_position_ids"][start_idx:end_idx] + micro_batch_chosen_inputs["position_ids"] = batch_td[ + "chosen_position_ids" + ][start_idx:end_idx] micro_batch_rejected_inputs = { "input_ids": batch_td["rejected_input_ids"][start_idx:end_idx], - "attention_mask": batch_td["rejected_attention_mask"][start_idx:end_idx], + "attention_mask": batch_td["rejected_attention_mask"][ + start_idx:end_idx + ], } if "rejected_position_ids" in batch_td: - micro_batch_rejected_inputs["position_ids"] = batch_td["rejected_position_ids"][start_idx:end_idx] + micro_batch_rejected_inputs["position_ids"] = batch_td[ + "rejected_position_ids" + ][start_idx:end_idx] # Determine autocast dtype - autocast_dtype = torch.bfloat16 # Or get dynamically from config/FSDP settings + autocast_dtype = ( + torch.bfloat16 + ) # Or get dynamically from config/FSDP settings # --- Autocast Forward Pass --- with torch.autocast(device_type=get_device_name(), dtype=autocast_dtype): # --- Step 1: Forward pass for CURRENT policy log probs (with grad) --- - policy_chosen_outputs = self.actor_module(**micro_batch_chosen_inputs, use_cache=False) - 
policy_rejected_outputs = self.actor_module(**micro_batch_rejected_inputs, use_cache=False) + policy_chosen_outputs = self.actor_module( + **micro_batch_chosen_inputs, use_cache=False + ) + policy_rejected_outputs = self.actor_module( + **micro_batch_rejected_inputs, use_cache=False + ) # --- Step 2: Calculate CURRENT policy log probs using get_batch_logps --- policy_chosen_logps = get_batch_logps( - policy_chosen_outputs.logits, micro_batch_chosen_labels, average_log_prob=False + policy_chosen_outputs.logits, + micro_batch_chosen_labels, + average_log_prob=False, ) policy_rejected_logps = get_batch_logps( - policy_rejected_outputs.logits, micro_batch_rejected_labels, average_log_prob=False + policy_rejected_outputs.logits, + micro_batch_rejected_labels, + average_log_prob=False, ) # --- Step 3: Retrieve PRE-CALCULATED reference log probs (NO grad needed) --- @@ -203,7 +248,9 @@ def update_policy_dpo_with_ref(self, data: DataProto): # --- Step 4: Calculate DPO Logits and Loss --- pi_logratios = policy_chosen_logps - policy_rejected_logps - ref_logratios = micro_ref_chosen_logps - micro_ref_rejected_logps # Uses pre-calculated values + ref_logratios = ( + micro_ref_chosen_logps - micro_ref_rejected_logps + ) # Uses pre-calculated values logits = pi_logratios - ref_logratios # DPO logits loss = compute_online_dpo_loss( @@ -223,11 +270,19 @@ def update_policy_dpo_with_ref(self, data: DataProto): # --- Accumulate Metrics --- total_loss += loss.item() # Unscaled loss accumulated_metrics["actor/dpo_loss_batch"].append(loss.item()) - accumulated_metrics["actor/dpo_logits_batch"].append(logits.mean().item()) + accumulated_metrics["actor/dpo_logits_batch"].append( + logits.mean().item() + ) # Accumulate policy and reference log probs/ratios if needed for debugging - accumulated_metrics["actor/policy_chosen_logps_batch"].append(policy_chosen_logps.mean().item()) - accumulated_metrics["actor/policy_rejected_logps_batch"].append(policy_rejected_logps.mean().item()) - 
accumulated_metrics["actor/reference_chosen_logps_batch"].append(micro_ref_chosen_logps.mean().item()) + accumulated_metrics["actor/policy_chosen_logps_batch"].append( + policy_chosen_logps.mean().item() + ) + accumulated_metrics["actor/policy_rejected_logps_batch"].append( + policy_rejected_logps.mean().item() + ) + accumulated_metrics["actor/reference_chosen_logps_batch"].append( + micro_ref_chosen_logps.mean().item() + ) accumulated_metrics["actor/reference_rejected_logps_batch"].append( micro_ref_rejected_logps.mean().item() ) @@ -237,7 +292,9 @@ def update_policy_dpo_with_ref(self, data: DataProto): if scaled_loss.requires_grad: scaled_loss.backward() else: - print(f"Warning: Scaled loss at micro-batch {i} does not require grad. Skipping backward.") + print( + f"Warning: Scaled loss at micro-batch {i} does not require grad. Skipping backward." + ) # --- End Micro-batch Loop --- @@ -248,7 +305,9 @@ def update_policy_dpo_with_ref(self, data: DataProto): if num_micro_batches > 0 and bsz > 0: # Check if any processing happened metrics["actor/dpo_loss"] = total_loss / num_micro_batches metrics["actor/grad_norm"] = ( - grad_norm.item() if torch.is_tensor(grad_norm) and torch.isfinite(grad_norm) else float("inf") + grad_norm.item() + if torch.is_tensor(grad_norm) and torch.isfinite(grad_norm) + else float("inf") ) # Average other accumulated metrics for key, val_list in accumulated_metrics.items(): @@ -262,17 +321,29 @@ def update_policy_dpo_with_ref(self, data: DataProto): and "actor/reference_chosen_logps" in metrics and "actor/reference_rejected_logps" in metrics ): - policy_ratio_mean = metrics["actor/policy_chosen_logps"] - metrics["actor/policy_rejected_logps"] - ref_ratio_mean = metrics["actor/reference_chosen_logps"] - metrics["actor/reference_rejected_logps"] + policy_ratio_mean = ( + metrics["actor/policy_chosen_logps"] + - metrics["actor/policy_rejected_logps"] + ) + ref_ratio_mean = ( + metrics["actor/reference_chosen_logps"] + - 
metrics["actor/reference_rejected_logps"] + ) logits_mean = policy_ratio_mean - ref_ratio_mean metrics["actor/rewards_chosen"] = beta * ( - metrics["actor/policy_chosen_logps"] - metrics["actor/reference_chosen_logps"] + metrics["actor/policy_chosen_logps"] + - metrics["actor/reference_chosen_logps"] ) metrics["actor/rewards_rejected"] = beta * ( - metrics["actor/policy_rejected_logps"] - metrics["actor/reference_rejected_logps"] + metrics["actor/policy_rejected_logps"] + - metrics["actor/reference_rejected_logps"] + ) + metrics["actor/rewards_accuracies"] = float( + logits_mean > 0 + ) # Mean accuracy proxy + metrics["actor/rewards_margins"] = ( + metrics["actor/rewards_chosen"] - metrics["actor/rewards_rejected"] ) - metrics["actor/rewards_accuracies"] = float(logits_mean > 0) # Mean accuracy proxy - metrics["actor/rewards_margins"] = metrics["actor/rewards_chosen"] - metrics["actor/rewards_rejected"] else: # Handle case where no micro-batches were run (e.g., bsz=0) metrics["actor/dpo_loss"] = 0.0 diff --git a/Agent0/executor_train/verl/recipe/spin/fsdp_workers.py b/Agent0/executor_train/verl/recipe/spin/fsdp_workers.py index e8a43e0..fa237ac 100644 --- a/Agent0/executor_train/verl/recipe/spin/fsdp_workers.py +++ b/Agent0/executor_train/verl/recipe/spin/fsdp_workers.py @@ -31,7 +31,12 @@ from verl.single_controller.base.decorator import Dispatch, register from verl.utils import hf_tokenizer from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager -from verl.utils.device import get_device_id, get_device_name, get_nccl_backend, get_torch_device +from verl.utils.device import ( + get_device_id, + get_device_name, + get_nccl_backend, + get_torch_device, +) from verl.utils.flops_counter import FlopsCounter from verl.utils.fs import copy_to_local from verl.utils.fsdp_utils import ( @@ -55,10 +60,14 @@ def create_device_mesh(world_size, fsdp_size): if fsdp_size < 0 or fsdp_size >= world_size: - device_mesh = init_device_mesh(get_device_name(), 
mesh_shape=(world_size,), mesh_dim_names=["fsdp"]) + device_mesh = init_device_mesh( + get_device_name(), mesh_shape=(world_size,), mesh_dim_names=["fsdp"] + ) else: device_mesh = init_device_mesh( - get_device_name(), mesh_shape=(world_size // fsdp_size, fsdp_size), mesh_dim_names=["ddp", "fsdp"] + get_device_name(), + mesh_shape=(world_size // fsdp_size, fsdp_size), + mesh_dim_names=["ddp", "fsdp"], ) return device_mesh @@ -71,21 +80,27 @@ def get_sharding_strategy(device_mesh): elif device_mesh.ndim == 2: sharding_strategy = ShardingStrategy.HYBRID_SHARD else: - raise NotImplementedError(f"Get device mesh ndim={device_mesh.ndim}, but only support 1 or 2") + raise NotImplementedError( + f"Get device mesh ndim={device_mesh.ndim}, but only support 1 or 2" + ) return sharding_strategy class SPINRolloutRefWorker(ActorRolloutRefWorker): @register(dispatch_mode=Dispatch.ONE_TO_ALL) def init_model(self): - from recipe.spin.dp_actor import SPINDataParallelPPOActor as DataParallelPPOActor + from recipe.spin.dp_actor import ( + SPINDataParallelPPOActor as DataParallelPPOActor, + ) # This is used to import external_lib into the huggingface systems import_external_libs(self.config.model.get("external_lib", None)) from omegaconf import OmegaConf - override_model_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create())) + override_model_config = OmegaConf.to_container( + self.config.model.get("override_config", OmegaConf.create()) + ) use_remove_padding = self.config.model.get("use_remove_padding", False) use_fused_kernels = self.config.model.get("use_fused_kernels", False) @@ -98,19 +113,24 @@ def init_model(self): else: optim_config = None fsdp_config = OmegaConf.create() - self.actor_module_fsdp, self.actor_optimizer, self.actor_lr_scheduler, self.actor_model_config = ( - self._build_model_optimizer( - model_path=self.config.model.path, - fsdp_config=fsdp_config, - optim_config=optim_config, - override_model_config=override_model_config, 
- use_remove_padding=use_remove_padding, - use_fused_kernels=use_fused_kernels, - enable_gradient_checkpointing=self.config.model.get("enable_gradient_checkpointing", False), - trust_remote_code=self.config.model.get("trust_remote_code", False), - use_liger=self.config.model.get("use_liger", False), - role="actor", - ) + ( + self.actor_module_fsdp, + self.actor_optimizer, + self.actor_lr_scheduler, + self.actor_model_config, + ) = self._build_model_optimizer( + model_path=self.config.model.path, + fsdp_config=fsdp_config, + optim_config=optim_config, + override_model_config=override_model_config, + use_remove_padding=use_remove_padding, + use_fused_kernels=use_fused_kernels, + enable_gradient_checkpointing=self.config.model.get( + "enable_gradient_checkpointing", False + ), + trust_remote_code=self.config.model.get("trust_remote_code", False), + use_liger=self.config.model.get("use_liger", False), + role="actor", ) # get the original unwrapped module @@ -118,7 +138,9 @@ def init_model(self): if self._is_offload_optimizer: offload_fsdp_optimizer(optimizer=self.actor_optimizer) - log_gpu_memory_usage("After offload actor optimizer during init", logger=logger) + log_gpu_memory_usage( + "After offload actor optimizer during init", logger=logger + ) # load from checkpoint if self._is_actor or self._is_ref: OmegaConf.set_struct(self.config.actor, True) @@ -126,7 +148,9 @@ def init_model(self): self.config.actor.use_remove_padding = use_remove_padding self.config.actor.use_fused_kernels = use_fused_kernels self.actor = DataParallelPPOActor( - config=self.config.actor, actor_module=self.actor_module_fsdp, actor_optimizer=self.actor_optimizer + config=self.config.actor, + actor_module=self.actor_module_fsdp, + actor_optimizer=self.actor_optimizer, ) if self._is_rollout: @@ -150,12 +174,16 @@ def init_model(self): with open_dict(self.config.ref): self.config.ref.use_remove_padding = use_remove_padding self.config.ref.use_fused_kernels = use_fused_kernels - self.ref_policy = 
DataParallelPPOActor(config=self.config.ref, actor_module=self.ref_module_fsdp) + self.ref_policy = DataParallelPPOActor( + config=self.config.ref, actor_module=self.ref_module_fsdp + ) self.checkpoint_manager = FSDPCheckpointManager( model=self.actor_module_fsdp, optimizer=self.actor.actor_optimizer, lr_scheduler=self.actor_lr_scheduler, - processing_class=self.processor if self.processor is not None else self.tokenizer, + processing_class=( + self.processor if self.processor is not None else self.tokenizer + ), checkpoint_config=self.config.actor.checkpoint, ) @@ -165,7 +193,9 @@ def init_model(self): model=self.actor_module_fsdp, optimizer=self.actor.actor_optimizer, lr_scheduler=self.actor_lr_scheduler, - processing_class=self.processor if self.processor is not None else self.tokenizer, + processing_class=( + self.processor if self.processor is not None else self.tokenizer + ), checkpoint_config=self.config.actor.checkpoint, ) @@ -205,8 +235,12 @@ def compute_log_prob(self, data: DataProto): # Support all hardwares data = data.to(get_device_id()) # we should always recompute old_log_probs when it is HybridEngine - data.meta_info["micro_batch_size"] = self.config.rollout.log_prob_micro_batch_size_per_gpu - data.meta_info["max_token_len"] = self.config.rollout.log_prob_max_token_len_per_gpu + data.meta_info["micro_batch_size"] = ( + self.config.rollout.log_prob_micro_batch_size_per_gpu + ) + data.meta_info["max_token_len"] = ( + self.config.rollout.log_prob_max_token_len_per_gpu + ) data.meta_info["use_dynamic_bsz"] = self.config.rollout.log_prob_use_dynamic_bsz data.meta_info["temperature"] = self.config.rollout.temperature # perform recompute log_prob @@ -214,7 +248,8 @@ def compute_log_prob(self, data: DataProto): data = self.ulysses_sharding_manager.preprocess_data(data) output = self.actor.compute_log_prob(data=data) output = DataProto.from_dict( - tensors={"old_log_probs": output}, meta_info={"temperature": self.config.rollout.temperature} + 
tensors={"old_log_probs": output}, + meta_info={"temperature": self.config.rollout.temperature}, ) output = self.ulysses_sharding_manager.postprocess_data(output) @@ -249,7 +284,9 @@ def update_actor_dpo(self, data: DataProto): if self._is_offload_param: load_fsdp_model_to_gpu(self.actor_module_fsdp) if self._is_offload_optimizer: - load_fsdp_optimizer(optimizer=self.actor_optimizer, device_id=get_device_id()) + load_fsdp_optimizer( + optimizer=self.actor_optimizer, device_id=get_device_id() + ) log_gpu_memory_usage("Before update policy (DPO via PPO path)", logger=logger) @@ -258,9 +295,13 @@ def update_actor_dpo(self, data: DataProto): data = self.ulysses_sharding_manager.preprocess_data(data=data) # --- Call the core update method (now containing DPO logic) --- - with Timer(name="update_policy_dpo_via_ppo", logger=None) as timer: # Use a distinct timer name + with Timer( + name="update_policy_dpo_via_ppo", logger=None + ) as timer: # Use a distinct timer name # Calls the modified update_policy method - metrics = self.actor.update_policy_dpo_with_ref(data=data) # <-- THIS CALLS THE MODIFIED FUNCTION + metrics = self.actor.update_policy_dpo_with_ref( + data=data + ) # <-- THIS CALLS THE MODIFIED FUNCTION delta_time = timer.last # --- Add Performance Metrics --- @@ -268,19 +309,34 @@ def update_actor_dpo(self, data: DataProto): metrics["perf/approx_tokens_processed"] = torch.sum( data.batch.get("attention_mask", torch.tensor(0)) ).item() # Approx tokens - metrics["perf/max_memory_allocated_gb"] = get_torch_device().max_memory_allocated() / (1024**3) - metrics["perf/max_memory_reserved_gb"] = get_torch_device().max_memory_reserved() / (1024**3) - metrics["perf/cpu_memory_used_gb"] = psutil.virtual_memory().used / (1024**3) + metrics["perf/max_memory_allocated_gb"] = ( + get_torch_device().max_memory_allocated() / (1024**3) + ) + metrics["perf/max_memory_reserved_gb"] = ( + get_torch_device().max_memory_reserved() / (1024**3) + ) + metrics["perf/cpu_memory_used_gb"] 
= psutil.virtual_memory().used / ( + 1024**3 + ) global_num_tokens = data.meta_info["global_token_num"] - estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time) - metrics["perf/mfu/actor"] = estimated_flops * self.config.ppo_epochs / promised_flops / self.world_size + estimated_flops, promised_flops = self.flops_counter.estimate_flops( + global_num_tokens, delta_time + ) + metrics["perf/mfu/actor"] = ( + estimated_flops + * self.config.ppo_epochs + / promised_flops + / self.world_size + ) # --- LR Scheduler Step --- lr = self.actor_lr_scheduler.get_last_lr()[0] metrics["actor/lr"] = lr self.actor_lr_scheduler.step() - log_gpu_memory_usage("After update policy (DPO via PPO path)", logger=logger) + log_gpu_memory_usage( + "After update policy (DPO via PPO path)", logger=logger + ) # --- Prepare Output --- output = DataProto(meta_info={"metrics": metrics}) @@ -315,17 +371,25 @@ def __init__(self, config): from torch.distributed.device_mesh import init_device_mesh fsdp_size = self.config.model.fsdp_config.fsdp_size - self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=fsdp_size) + self.device_mesh = create_device_mesh( + world_size=world_size, fsdp_size=fsdp_size + ) self.ulysses_device_mesh = None - self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1) + self.ulysses_sequence_parallel_size = self.config.get( + "ulysses_sequence_parallel_size", 1 + ) dp = world_size // self.ulysses_sequence_parallel_size if self.ulysses_sequence_parallel_size > 1: self.ulysses_device_mesh = init_device_mesh( - get_device_name(), mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"] + get_device_name(), + mesh_shape=(dp, self.ulysses_sequence_parallel_size), + mesh_dim_names=["dp", "sp"], ) - self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh) + self.ulysses_sharding_manager = FSDPUlyssesShardingManager( + 
self.ulysses_device_mesh + ) self.use_remove_padding = self.config.model.get("use_remove_padding", False) @@ -349,12 +413,18 @@ def _build_model(self, config): self._do_switch_chat_template = True input_tokenizer_local_path = copy_to_local(config.model.input_tokenizer) self.input_tokenizer = hf_tokenizer( - input_tokenizer_local_path, trust_remote_code=config.model.get("trust_remote_code", False) + input_tokenizer_local_path, + trust_remote_code=config.model.get("trust_remote_code", False), + ) + self.tokenizer = hf_tokenizer( + local_path, + trust_remote_code=config.model.get("trust_remote_code", False), ) - self.tokenizer = hf_tokenizer(local_path, trust_remote_code=config.model.get("trust_remote_code", False)) trust_remote_code = config.model.get("trust_remote_code", False) - model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code) + model_config = AutoConfig.from_pretrained( + local_path, trust_remote_code=trust_remote_code + ) model_config.num_labels = 1 # note that we have to create model in fp32. 
Otherwise, the optimizer is in bf16, which is incorrect @@ -373,14 +443,22 @@ def _build_model(self, config): trust_remote_code=trust_remote_code, ) - if config.model.get("use_remove_padding", False) or self.ulysses_sequence_parallel_size > 1: + if ( + config.model.get("use_remove_padding", False) + or self.ulysses_sequence_parallel_size > 1 + ): from verl.models.transformers.monkey_patch import apply_monkey_patch - apply_monkey_patch(model=reward_module, ulysses_sp_size=self.ulysses_sequence_parallel_size) + apply_monkey_patch( + model=reward_module, + ulysses_sp_size=self.ulysses_sequence_parallel_size, + ) reward_module.to(torch.bfloat16) - auto_wrap_policy = get_fsdp_wrap_policy(module=reward_module, config=self.config.model.fsdp_config) + auto_wrap_policy = get_fsdp_wrap_policy( + module=reward_module, config=self.config.model.fsdp_config + ) fsdp_mesh = self.device_mesh sharding_strategy = get_sharding_strategy(fsdp_mesh) @@ -407,11 +485,21 @@ def init_model(self): self.reward_module = self._build_model(config=self.config) def _forward_micro_batch(self, micro_batch): - from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input + from flash_attn.bert_padding import ( + index_first_axis, + pad_input, + rearrange, + unpad_input, + ) - from verl.utils.ulysses import gather_outpus_and_unpad, ulysses_pad_and_slice_inputs + from verl.utils.ulysses import ( + gather_outpus_and_unpad, + ulysses_pad_and_slice_inputs, + ) - with torch.no_grad(), torch.autocast(device_type=get_device_name(), dtype=torch.bfloat16): + with torch.no_grad(), torch.autocast( + device_type=get_device_name(), dtype=torch.bfloat16 + ): input_ids = micro_batch["input_ids"] batch_size, seqlen = input_ids.shape attention_mask = micro_batch["attention_mask"] @@ -425,18 +513,26 @@ def _forward_micro_batch(self, micro_batch): # unpad the position_ids to align the rotary position_ids_rmpad = index_first_axis( - rearrange(position_ids.unsqueeze(-1), "b s ... 
-> (b s) ..."), indices + rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), + indices, ).transpose(0, 1) # pad and slice the inputs if sp > 1 if self.ulysses_sequence_parallel_size > 1: - input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs( - input_ids_rmpad, position_ids_rmpad, sp_size=self.ulysses_sequence_parallel_size + input_ids_rmpad, position_ids_rmpad, pad_size = ( + ulysses_pad_and_slice_inputs( + input_ids_rmpad, + position_ids_rmpad, + sp_size=self.ulysses_sequence_parallel_size, + ) ) # only pass input_ids and position_ids to enable flash_attn_varlen output = self.reward_module( - input_ids=input_ids_rmpad, attention_mask=None, position_ids=position_ids_rmpad, use_cache=False + input_ids=input_ids_rmpad, + attention_mask=None, + position_ids=position_ids_rmpad, + use_cache=False, ) # prevent model thinks we are generating reward_rmpad = output.logits reward_rmpad = reward_rmpad.squeeze(0) # (total_nnz) @@ -448,10 +544,15 @@ def _forward_micro_batch(self, micro_batch): ) # pad it back - rm_score = pad_input(reward_rmpad, indices=indices, batch=batch_size, seqlen=seqlen).squeeze(-1) + rm_score = pad_input( + reward_rmpad, indices=indices, batch=batch_size, seqlen=seqlen + ).squeeze(-1) else: output = self.reward_module( - input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + use_cache=False, ) rm_score = output.logits # (batch_size, seq_len, 1) rm_score = rm_score.squeeze(-1) @@ -468,7 +569,9 @@ def _expand_to_token_level(self, data: DataProto, scores: torch.Tensor): position_ids = data.batch["position_ids"] response_length = data.batch["responses"].shape[-1] eos_mask_idx = torch.argmax(position_ids * attention_mask, dim=-1) # (bsz,) - token_level_scores = torch.zeros_like(attention_mask, dtype=scores.dtype) # (bsz, seqlen) + token_level_scores = torch.zeros_like( + attention_mask, 
dtype=scores.dtype + ) # (bsz, seqlen) token_level_scores[torch.arange(batch_size), eos_mask_idx] = scores # select the response part @@ -495,7 +598,9 @@ def _switch_chat_template(self, data: DataProto): # extract response response_ids = data.batch["responses"][i] response_length = response_ids.shape[-1] - valid_response_length = data.batch["attention_mask"][i][-response_length:].sum() + valid_response_length = data.batch["attention_mask"][i][ + -response_length: + ].sum() valid_response_ids = response_ids[:valid_response_length] # decode @@ -517,7 +622,9 @@ def _switch_chat_template(self, data: DataProto): if max_length is None: max_length = src_max_length - model_inputs = target_tokenizer(prompt_with_chat_template, return_tensors="pt", add_special_tokens=False) + model_inputs = target_tokenizer( + prompt_with_chat_template, return_tensors="pt", add_special_tokens=False + ) input_ids, attention_mask = verl_F.postprocess_data( input_ids=model_inputs["input_ids"], attention_mask=model_inputs["attention_mask"], @@ -535,7 +642,11 @@ def _switch_chat_template(self, data: DataProto): rm_position_ids = compute_position_id_with_mask(rm_attention_mask) - rm_inputs = {"input_ids": rm_input_ids, "attention_mask": rm_attention_mask, "position_ids": rm_position_ids} + rm_inputs = { + "input_ids": rm_input_ids, + "attention_mask": rm_attention_mask, + "position_ids": rm_position_ids, + } return DataProto.from_dict(rm_inputs) @@ -570,10 +681,17 @@ def compute_rm_score(self, data: DataProto): use_dynamic_bsz = self.config.use_dynamic_bsz if use_dynamic_bsz: - max_token_len = self.config.forward_max_token_len_per_gpu * self.ulysses_sequence_parallel_size - micro_batches, indices = rearrange_micro_batches(batch=rm_data.batch, max_token_len=max_token_len) + max_token_len = ( + self.config.forward_max_token_len_per_gpu + * self.ulysses_sequence_parallel_size + ) + micro_batches, indices = rearrange_micro_batches( + batch=rm_data.batch, max_token_len=max_token_len + ) else: - 
micro_batches = rm_data.batch.split(self.config.micro_batch_size_per_gpu) + micro_batches = rm_data.batch.split( + self.config.micro_batch_size_per_gpu + ) output = [] for micro_batch in micro_batches: rm_score = self._forward_micro_batch(micro_batch) @@ -582,8 +700,12 @@ def compute_rm_score(self, data: DataProto): if use_dynamic_bsz: indices = list(itertools.chain.from_iterable(indices)) - assert len(indices) == scores.size(0), f"{len(indices)} vs. {scores.size()}" - revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long) + assert len(indices) == scores.size( + 0 + ), f"{len(indices)} vs. {scores.size()}" + revert_indices = torch.tensor( + get_reverse_idx(indices), dtype=torch.long + ) scores = scores[revert_indices] token_level_scores = self._expand_to_token_level(data, scores) diff --git a/Agent0/executor_train/verl/recipe/spin/main_spin.py b/Agent0/executor_train/verl/recipe/spin/main_spin.py index 9a879ee..aced2e1 100644 --- a/Agent0/executor_train/verl/recipe/spin/main_spin.py +++ b/Agent0/executor_train/verl/recipe/spin/main_spin.py @@ -30,12 +30,18 @@ def main(config): def run_ppo(config) -> None: # TODO(linjunrong.ocss884): this ENV is left for resolving SGLang conflict with ray devices # isolation, will solve in the future - os.environ["ENSURE_CUDA_VISIBLE_DEVICES"] = os.environ.get("CUDA_VISIBLE_DEVICES", "") + os.environ["ENSURE_CUDA_VISIBLE_DEVICES"] = os.environ.get( + "CUDA_VISIBLE_DEVICES", "" + ) if not ray.is_initialized(): # this is for local ray cluster ray.init( runtime_env={ - "env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN", "VLLM_LOGGING_LEVEL": "WARN"} + "env_vars": { + "TOKENIZERS_PARALLELISM": "true", + "NCCL_DEBUG": "WARN", + "VLLM_LOGGING_LEVEL": "WARN", + } } ) @@ -53,7 +59,9 @@ def run(self, config): from verl.utils.fs import copy_to_local - pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values + pprint( + OmegaConf.to_container(config, resolve=True) + ) # 
resolve=True will eval symbol values OmegaConf.resolve(config) # download the checkpoint from hdfs @@ -64,7 +72,9 @@ def run(self, config): trust_remote_code = config.data.get("trust_remote_code", False) tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code) - processor = hf_processor(local_path, use_fast=True) # used for multimodal LLM, could be none + processor = hf_processor( + local_path, use_fast=True + ) # used for multimodal LLM, could be none # define worker classes if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}: @@ -136,9 +146,14 @@ def run(self, config): # Note that we always use function-based RM for validation val_reward_fn = reward_manager_cls( - tokenizer=tokenizer, num_examine=1, compute_score=compute_score, reward_fn_key=config.data.reward_fn_key + tokenizer=tokenizer, + num_examine=1, + compute_score=compute_score, + reward_fn_key=config.data.reward_fn_key, + ) + resource_pool_manager = ResourcePoolManager( + resource_pool_spec=resource_pool_spec, mapping=mapping ) - resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) trainer = RaySPINTrainer( config=config, diff --git a/Agent0/executor_train/verl/recipe/spin/spin_trainer.py b/Agent0/executor_train/verl/recipe/spin/spin_trainer.py index fa435db..46db847 100644 --- a/Agent0/executor_train/verl/recipe/spin/spin_trainer.py +++ b/Agent0/executor_train/verl/recipe/spin/spin_trainer.py @@ -36,7 +36,11 @@ from verl import DataProto from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto from verl.single_controller.base import Worker -from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup +from verl.single_controller.ray import ( + RayClassWithInitArgs, + RayResourcePool, + RayWorkerGroup, +) from verl.single_controller.ray.base import create_colocated_worker_cls from verl.trainer.ppo.metric_utils import ( compute_throughout_metrics, @@ -46,7 +50,10 @@ ) from 
verl.trainer.ppo.ray_trainer import Role from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path -from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance +from verl.utils.seqlen_balancing import ( + get_seqlen_balanced_partitions, + log_seqlen_unbalance, +) from verl.utils.torch_functional import masked_mean from verl.utils.tracking import ValidationGenerationsLogger @@ -84,7 +91,10 @@ def create_resource_pool(self): # For Megatron backend, we recommend using max_colocate_count>1 that can utilize different # WorkerGroup for different models resource_pool = RayResourcePool( - process_on_nodes=process_on_nodes, use_gpu=True, max_colocate_count=1, name_prefix=resource_pool_name + process_on_nodes=process_on_nodes, + use_gpu=True, + max_colocate_count=1, + name_prefix=resource_pool_name, ) self.resource_pool_dict[resource_pool_name] = resource_pool @@ -96,17 +106,30 @@ def get_resource_pool(self, role: Role) -> RayResourcePool: def get_n_gpus(self) -> int: """Get the number of gpus in this cluster.""" - return sum([n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes]) + return sum( + [ + n_gpus + for process_on_nodes in self.resource_pool_spec.values() + for n_gpus in process_on_nodes + ] + ) def _check_resource_available(self): """Check if the resource pool can be satisfied in this ray cluster.""" node_available_resources = ray.state.available_resources_per_node() - node_available_gpus = {node: node_info.get("GPU", 0) for node, node_info in node_available_resources.items()} + node_available_gpus = { + node: node_info.get("GPU", 0) + for node, node_info in node_available_resources.items() + } # check total required gpus can be satisfied total_available_gpus = sum(node_available_gpus.values()) total_required_gpus = sum( - [n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes] + [ + n_gpus + for process_on_nodes in 
self.resource_pool_spec.values() + for n_gpus in process_on_nodes + ] ) if total_available_gpus < total_required_gpus: raise ValueError( @@ -138,8 +161,12 @@ def _compute_response_info(batch: DataProto) -> dict[str, Any]: # This is simplified - real implementation might use attention masks # to get actual lengths per sample. batch_size = batch.batch.batch_size[0] - prompt_lengths_tensor = torch.full((batch_size,), prompt_len, dtype=torch.float32, device=batch.batch.device) - response_lengths_tensor = torch.full((batch_size,), resp_len, dtype=torch.float32, device=batch.batch.device) + prompt_lengths_tensor = torch.full( + (batch_size,), prompt_len, dtype=torch.float32, device=batch.batch.device + ) + response_lengths_tensor = torch.full( + (batch_size,), resp_len, dtype=torch.float32, device=batch.batch.device + ) # Try getting actual lengths from attention mask if possible (more accurate) if "response_mask" in batch.batch: @@ -152,7 +179,9 @@ def _compute_response_info(batch: DataProto) -> dict[str, Any]: # Example: prompt_lengths_tensor = full_mask.sum(dim=1).float() - response_lengths_tensor # Fallback to using prompt shape if mask logic is complex: prompt_lengths_tensor = torch.tensor( - [batch.batch["prompts"].shape[1]] * batch_size, dtype=torch.float32, device=batch.batch.device + [batch.batch["prompts"].shape[1]] * batch_size, + dtype=torch.float32, + device=batch.batch.device, ) return { @@ -162,11 +191,21 @@ def _compute_response_info(batch: DataProto) -> dict[str, Any]: "max_prompt_length": prompt_len, # Or from config if fixed padding } except KeyError as e: - print(f"Warning: Missing key in _compute_response_info: {e}. Returning defaults.") + print( + f"Warning: Missing key in _compute_response_info: {e}. Returning defaults." 
+ ) # Return default/dummy values if keys are missing b_size = batch.batch.batch_size[0] if batch.batch.batch_size else 1 - max_resp = batch.batch.get("responses").shape[1] if batch.batch.get("responses") is not None else 0 - max_prompt = batch.batch.get("prompts").shape[1] if batch.batch.get("prompts") is not None else 0 + max_resp = ( + batch.batch.get("responses").shape[1] + if batch.batch.get("responses") is not None + else 0 + ) + max_prompt = ( + batch.batch.get("prompts").shape[1] + if batch.batch.get("prompts") is not None + else 0 + ) return { "prompt_length": torch.zeros(b_size), "response_length": torch.zeros(b_size), @@ -187,7 +226,10 @@ def compute_dpo_data_metrics(batch: DataProto) -> dict[str, Any]: metrics = {} try: # --- Scores and Rewards (from reward_fn) --- - if "token_level_scores" in batch.batch and batch.batch["token_level_scores"] is not None: + if ( + "token_level_scores" in batch.batch + and batch.batch["token_level_scores"] is not None + ): sequence_score = batch.batch["token_level_scores"].sum(-1) metrics.update( { @@ -199,7 +241,10 @@ def compute_dpo_data_metrics(batch: DataProto) -> dict[str, Any]: else: print("DEBUG compute_dpo_data_metrics: 'token_level_scores' not found.") - if "token_level_rewards" in batch.batch and batch.batch["token_level_rewards"] is not None: + if ( + "token_level_rewards" in batch.batch + and batch.batch["token_level_rewards"] is not None + ): sequence_reward = batch.batch["token_level_rewards"].sum(-1) metrics.update( { @@ -222,8 +267,13 @@ def compute_dpo_data_metrics(batch: DataProto) -> dict[str, Any]: else: print("DEBUG compute_dpo_data_metrics: 'chosen_logps' not found.") - if "rejected_logps" in batch.batch and batch.batch["rejected_logps"] is not None: - metrics["actor/rejected_logps"] = batch.batch["rejected_logps"].mean().item() + if ( + "rejected_logps" in batch.batch + and batch.batch["rejected_logps"] is not None + ): + metrics["actor/rejected_logps"] = ( + 
batch.batch["rejected_logps"].mean().item() + ) else: print("DEBUG compute_dpo_data_metrics: 'rejected_logps' not found.") @@ -239,19 +289,25 @@ def compute_dpo_data_metrics(batch: DataProto) -> dict[str, Any]: prompt_length = response_info["prompt_length"] response_length = response_info["response_length"] max_response_length = response_info["max_response_length"] - max_prompt_length = response_info["max_prompt_length"] # Use calculated or from config + max_prompt_length = response_info[ + "max_prompt_length" + ] # Use calculated or from config metrics.update( { "response_length/mean": torch.mean(response_length).item(), "response_length/max": torch.max(response_length).item(), "response_length/min": torch.min(response_length).item(), - "response_length/clip_ratio": torch.mean(torch.eq(response_length, max_response_length).float()).item(), + "response_length/clip_ratio": torch.mean( + torch.eq(response_length, max_response_length).float() + ).item(), "prompt_length/mean": torch.mean(prompt_length).item(), "prompt_length/max": torch.max(prompt_length).item(), "prompt_length/min": torch.min(prompt_length).item(), # Prompt clip ratio might need adjustment based on how max_prompt_length is defined - "prompt_length/clip_ratio": torch.mean(torch.eq(prompt_length, max_prompt_length).float()).item(), + "prompt_length/clip_ratio": torch.mean( + torch.eq(prompt_length, max_prompt_length).float() + ).item(), } ) @@ -265,7 +321,9 @@ def compute_dpo_data_metrics(batch: DataProto) -> dict[str, Any]: return metrics -def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, kl_penalty="kl"): +def apply_kl_penalty( + data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, kl_penalty="kl" +): responses = data.batch["responses"] response_length = responses.size(1) token_level_scores = data.batch["token_level_scores"] @@ -290,7 +348,10 @@ def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, kl_ctrl.update(current_kl=current_kl, 
n_steps=batch_size) data.batch["token_level_rewards"] = token_level_rewards - metrics = {"actor/reward_kl_penalty": current_kl, "actor/reward_kl_penalty_coeff": beta} + metrics = { + "actor/reward_kl_penalty": current_kl, + "actor/reward_kl_penalty_coeff": beta, + } return data, metrics @@ -315,18 +376,24 @@ def compute_onlineDPO_pref(data: DataProto): mask_tensor = data.batch.get("response_mask") if rewards_tensor is None or mask_tensor is None: - print(" ERROR: Missing 'token_level_rewards' or 'response_mask' in input data!") + print( + " ERROR: Missing 'token_level_rewards' or 'response_mask' in input data!" + ) # Handle error case - maybe return original data or raise? # Returning original data for now to potentially allow skipping return data try: - preferences = core_algos.compute_onlinedpo_pref(token_level_rewards=rewards_tensor, response_mask=mask_tensor) + preferences = core_algos.compute_onlinedpo_pref( + token_level_rewards=rewards_tensor, response_mask=mask_tensor + ) # Store the result data.batch["preferences"] = preferences except AttributeError: - print("ERROR: Function 'compute_online_dpo_preference' not found in core_algos.py!") + print( + "ERROR: Function 'compute_online_dpo_preference' not found in core_algos.py!" 
+ ) # Assign dummy value or raise error data.batch["preferences"] = None # Indicate failure except Exception as e_pref: @@ -382,7 +449,9 @@ def __init__( assert self.hybrid_engine, "Currently, only support hybrid engine" if self.hybrid_engine: - assert Role.ActorRollout in role_worker_mapping, f"{role_worker_mapping.keys()=}" + assert ( + Role.ActorRollout in role_worker_mapping + ), f"{role_worker_mapping.keys()=}" self.role_worker_mapping = role_worker_mapping self.resource_pool_manager = resource_pool_manager @@ -396,7 +465,9 @@ def __init__( # define in-reward KL control # kl loss control currently not suppoorted if config.algorithm.use_kl_in_reward: - self.kl_ctrl_in_reward = core_algos.get_kl_controller(config.algorithm.kl_ctrl) + self.kl_ctrl_in_reward = core_algos.get_kl_controller( + config.algorithm.kl_ctrl + ) self.use_critic = False self._validate_config() @@ -408,10 +479,12 @@ def _validate_config(self): n_gpus = config.trainer.n_gpus_per_node * config.trainer.nnodes # 1. Check total batch size for data correctness - real_train_batch_size = config.data.train_batch_size * config.actor_rollout_ref.rollout.n - assert real_train_batch_size % n_gpus == 0, ( - f"real_train_batch_size ({real_train_batch_size}) must be divisible by total n_gpus ({n_gpus})." + real_train_batch_size = ( + config.data.train_batch_size * config.actor_rollout_ref.rollout.n ) + assert ( + real_train_batch_size % n_gpus == 0 + ), f"real_train_batch_size ({real_train_batch_size}) must be divisible by total n_gpus ({n_gpus})." # A helper function to check "micro_batch_size" vs "micro_batch_size_per_gpu" # We throw an error if the user sets both. The new convention is "..._micro_batch_size_per_gpu". 
@@ -466,13 +539,17 @@ def check_mutually_exclusive(mbs, mbs_per_gpu, name: str): if self.use_critic and not config.critic.use_dynamic_bsz: # Check for critic micro-batch size conflicts check_mutually_exclusive( - config.critic.ppo_micro_batch_size, config.critic.ppo_micro_batch_size_per_gpu, "critic" + config.critic.ppo_micro_batch_size, + config.critic.ppo_micro_batch_size_per_gpu, + "critic", ) # Check for reward model micro-batch size conflicts if config.reward_model.enable and not config.reward_model.use_dynamic_bsz: check_mutually_exclusive( - config.reward_model.micro_batch_size, config.reward_model.micro_batch_size_per_gpu, "reward_model" + config.reward_model.micro_batch_size, + config.reward_model.micro_batch_size_per_gpu, + "reward_model", ) # Actor @@ -481,15 +558,23 @@ def check_mutually_exclusive(mbs, mbs_per_gpu, name: str): # ppo_mini_batch_size is divisible by ppo_micro_batch_size # ppo_micro_batch_size * sequence_parallel_size >= n_gpus if not config.actor_rollout_ref.actor.use_dynamic_bsz: - assert config.data.train_batch_size >= config.actor_rollout_ref.actor.ppo_mini_batch_size - sp_size = config.actor_rollout_ref.actor.get("ulysses_sequence_parallel_size", 1) + assert ( + config.data.train_batch_size + >= config.actor_rollout_ref.actor.ppo_mini_batch_size + ) + sp_size = config.actor_rollout_ref.actor.get( + "ulysses_sequence_parallel_size", 1 + ) if config.actor_rollout_ref.actor.ppo_micro_batch_size is not None: assert ( config.actor_rollout_ref.actor.ppo_mini_batch_size % config.actor_rollout_ref.actor.ppo_micro_batch_size == 0 ) - assert config.actor_rollout_ref.actor.ppo_micro_batch_size * sp_size >= n_gpus + assert ( + config.actor_rollout_ref.actor.ppo_micro_batch_size * sp_size + >= n_gpus + ) assert config.actor_rollout_ref.actor.loss_agg_mode in [ "token-mean", @@ -497,7 +582,10 @@ def check_mutually_exclusive(mbs, mbs_per_gpu, name: str): "seq-mean-token-mean", ], f"Invalid loss_agg_mode: 
{config.actor_rollout_ref.actor.loss_agg_mode}" - if config.algorithm.use_kl_in_reward and config.actor_rollout_ref.actor.use_kl_loss: + if ( + config.algorithm.use_kl_in_reward + and config.actor_rollout_ref.actor.use_kl_loss + ): print("NOTICE: You have both enabled in-reward kl and kl loss.") # critic @@ -505,24 +593,30 @@ def check_mutually_exclusive(mbs, mbs_per_gpu, name: str): assert config.data.train_batch_size >= config.critic.ppo_mini_batch_size sp_size = config.critic.get("ulysses_sequence_parallel_size", 1) if config.critic.ppo_micro_batch_size is not None: - assert config.critic.ppo_mini_batch_size % config.critic.ppo_micro_batch_size == 0 + assert ( + config.critic.ppo_mini_batch_size + % config.critic.ppo_micro_batch_size + == 0 + ) assert config.critic.ppo_micro_batch_size * sp_size >= n_gpus # Check if use_remove_padding is enabled when using sequence parallelism for fsdp if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}: if ( - config.actor_rollout_ref.actor.get("ulysses_sequence_parallel_size", 1) > 1 - or config.actor_rollout_ref.ref.get("ulysses_sequence_parallel_size", 1) > 1 + config.actor_rollout_ref.actor.get("ulysses_sequence_parallel_size", 1) + > 1 + or config.actor_rollout_ref.ref.get("ulysses_sequence_parallel_size", 1) + > 1 ): - assert config.actor_rollout_ref.model.use_remove_padding, ( - "When using sequence parallelism for actor/ref policy, you must enable `use_remove_padding`." - ) + assert ( + config.actor_rollout_ref.model.use_remove_padding + ), "When using sequence parallelism for actor/ref policy, you must enable `use_remove_padding`." if self.use_critic and config.critic.strategy in {"fsdp", "fsdp2"}: if config.critic.get("ulysses_sequence_parallel_size", 1) > 1: - assert config.critic.model.use_remove_padding, ( - "When using sequence parallelism for critic, you must enable `use_remove_padding`." 
- ) + assert ( + config.critic.model.use_remove_padding + ), "When using sequence parallelism for critic, you must enable `use_remove_padding`." if config.data.get("val_batch_size", None) is not None: print( @@ -532,9 +626,9 @@ def check_mutually_exclusive(mbs, mbs_per_gpu, name: str): # check eval config if config.actor_rollout_ref.rollout.val_kwargs.do_sample: - assert config.actor_rollout_ref.rollout.temperature > 0, ( - "validation gen temperature should be greater than 0 when enabling do_sample" - ) + assert ( + config.actor_rollout_ref.rollout.temperature > 0 + ), "validation gen temperature should be greater than 0 when enabling do_sample" print("[validate_config] All configuration checks passed successfully!") @@ -547,11 +641,17 @@ def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampl if train_dataset is None: train_dataset = create_rl_dataset( - self.config.data.train_files, self.config.data, self.tokenizer, self.processor + self.config.data.train_files, + self.config.data, + self.tokenizer, + self.processor, ) if val_dataset is None: val_dataset = create_rl_dataset( - self.config.data.val_files, self.config.data, self.tokenizer, self.processor + self.config.data.val_files, + self.config.data, + self.tokenizer, + self.processor, ) self.train_dataset, self.val_dataset = train_dataset, val_dataset @@ -564,7 +664,9 @@ def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampl self.train_dataloader = StatefulDataLoader( dataset=self.train_dataset, - batch_size=self.config.data.get("gen_batch_size", self.config.data.train_batch_size), + batch_size=self.config.data.get( + "gen_batch_size", self.config.data.train_batch_size + ), num_workers=self.config.data.get("dataloader_num_workers", 8), drop_last=True, collate_fn=collate_fn, @@ -592,7 +694,9 @@ def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampl f"Size of val dataloader: {len(self.val_dataloader)}" ) - total_training_steps = 
len(self.train_dataloader) * self.config.trainer.total_epochs + total_training_steps = ( + len(self.train_dataloader) * self.config.trainer.total_epochs + ) if self.config.trainer.total_training_steps is not None: total_training_steps = self.config.trainer.total_training_steps @@ -604,11 +708,15 @@ def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampl OmegaConf.set_struct(self.config, True) with open_dict(self.config): if OmegaConf.select(self.config, "actor_rollout_ref.actor.optim"): - self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps + self.config.actor_rollout_ref.actor.optim.total_training_steps = ( + total_training_steps + ) if OmegaConf.select(self.config, "critic.optim"): self.config.critic.optim.total_training_steps = total_training_steps except Exception as e: - print(f"Warning: Could not set total_training_steps in config. Structure missing? Error: {e}") + print( + f"Warning: Could not set total_training_steps in config. Structure missing? 
Error: {e}" + ) def _maybe_log_val_generations(self, inputs, outputs, scores): """Log a table of validation samples to the configured logger (wandb or swanlab)""" @@ -632,7 +740,9 @@ def _maybe_log_val_generations(self, inputs, outputs, scores): samples = samples[:generations_to_log] # Log to each configured logger - self.validation_generations_logger.log(self.config.trainer.logger, samples, self.global_steps) + self.validation_generations_logger.log( + self.config.trainer.logger, samples, self.global_steps + ) def _validate(self): data_source_lst = [] @@ -648,23 +758,32 @@ def _validate(self): # repeat test batch test_batch = test_batch.repeat( - repeat_times=self.config.actor_rollout_ref.rollout.val_kwargs.n, interleave=True + repeat_times=self.config.actor_rollout_ref.rollout.val_kwargs.n, + interleave=True, ) # we only do validation on rule-based rm - if self.config.reward_model.enable and test_batch[0].non_tensor_batch["reward_model"]["style"] == "model": + if ( + self.config.reward_model.enable + and test_batch[0].non_tensor_batch["reward_model"]["style"] == "model" + ): return {} # Store original inputs input_ids = test_batch.batch["input_ids"] # TODO: Can we keep special tokens except for padding tokens? 
- input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids] + input_texts = [ + self.tokenizer.decode(ids, skip_special_tokens=True) + for ids in input_ids + ] sample_inputs.extend(input_texts) batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"] non_tensor_batch_keys_to_pop = ["raw_prompt_ids"] if "multi_modal_inputs" in test_batch.non_tensor_batch: - non_tensor_batch_keys_to_pop.extend(["multi_modal_data", "multi_modal_inputs"]) + non_tensor_batch_keys_to_pop.extend( + ["multi_modal_data", "multi_modal_inputs"] + ) if "raw_prompt" in test_batch.non_tensor_batch: non_tensor_batch_keys_to_pop.append("raw_prompt") if "tools_kwargs" in test_batch.non_tensor_batch: @@ -684,19 +803,30 @@ def _validate(self): print(f"test_gen_batch meta info: {test_gen_batch.meta_info}") # pad to be divisible by dp_size - test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, self.actor_rollout_wg.world_size) + test_gen_batch_padded, pad_size = pad_dataproto_to_divisor( + test_gen_batch, self.actor_rollout_wg.world_size + ) if not self.async_rollout_mode: - test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences(test_gen_batch_padded) + test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences( + test_gen_batch_padded + ) else: - test_output_gen_batch_padded = self.async_rollout_manager.generate_sequences(test_gen_batch_padded) + test_output_gen_batch_padded = ( + self.async_rollout_manager.generate_sequences(test_gen_batch_padded) + ) # unpad - test_output_gen_batch = unpad_dataproto(test_output_gen_batch_padded, pad_size=pad_size) + test_output_gen_batch = unpad_dataproto( + test_output_gen_batch_padded, pad_size=pad_size + ) print("validation generation end") # Store generated outputs output_ids = test_output_gen_batch.batch["responses"] - output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids] + output_texts = [ + self.tokenizer.decode(ids, 
skip_special_tokens=True) + for ids in output_ids + ] sample_outputs.extend(output_texts) test_batch = test_batch.union(test_output_gen_batch) @@ -712,9 +842,15 @@ def _validate(self): for key, lst in result["reward_extra_info"].items(): reward_extra_infos_dict[key].extend(lst) - data_source_lst.append(test_batch.non_tensor_batch.get("data_source", ["unknown"] * reward_tensor.shape[0])) + data_source_lst.append( + test_batch.non_tensor_batch.get( + "data_source", ["unknown"] * reward_tensor.shape[0] + ) + ) - self._maybe_log_val_generations(inputs=sample_inputs, outputs=sample_outputs, scores=sample_scores) + self._maybe_log_val_generations( + inputs=sample_inputs, outputs=sample_outputs, scores=sample_scores + ) # dump generations val_data_dir = self.config.trainer.get("validation_data_dir", None) @@ -728,13 +864,19 @@ def _validate(self): ) for key_info, lst in reward_extra_infos_dict.items(): - assert len(lst) == 0 or len(lst) == len(sample_scores), f"{key_info}: {len(lst)=}, {len(sample_scores)=}" + assert len(lst) == 0 or len(lst) == len( + sample_scores + ), f"{key_info}: {len(lst)=}, {len(sample_scores)=}" data_sources = np.concatenate(data_source_lst, axis=0) print(f"DEBUG: Data sources shape: {data_sources.shape}") # Added Print - print(f"DEBUG: reward_extra_infos_dict keys before processing: {reward_extra_infos_dict.keys()}") # Added Print + print( + f"DEBUG: reward_extra_infos_dict keys before processing: {reward_extra_infos_dict.keys()}" + ) # Added Print - data_src2var2metric2val = process_validation_metrics(data_sources, sample_inputs, reward_extra_infos_dict) + data_src2var2metric2val = process_validation_metrics( + data_sources, sample_inputs, reward_extra_infos_dict + ) print( f"DEBUG: Output of process_validation_metrics (data_src2var2metric2val): {data_src2var2metric2val}" ) # Added Print @@ -742,11 +884,19 @@ def _validate(self): for data_source, var2metric2val in data_src2var2metric2val.items(): core_var = "acc" if "acc" in var2metric2val else 
"reward" for var_name, metric2val in var2metric2val.items(): - n_max = max([int(name.split("@")[-1].split("/")[0]) for name in metric2val.keys()]) + n_max = max( + [ + int(name.split("@")[-1].split("/")[0]) + for name in metric2val.keys() + ] + ) for metric_name, metric_val in metric2val.items(): if ( (var_name == core_var) - and any(metric_name.startswith(pfx) for pfx in ["mean", "maj", "best"]) + and any( + metric_name.startswith(pfx) + for pfx in ["mean", "maj", "best"] + ) and (f"@{n_max}" in metric_name) ): metric_sec = "val-core" @@ -761,39 +911,54 @@ def init_workers(self): """Init resource pool and worker group""" self.resource_pool_manager.create_resource_pool() - self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()} + self.resource_pool_to_cls = { + pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values() + } # create actor and rollout if self.hybrid_engine: - resource_pool = self.resource_pool_manager.get_resource_pool(Role.ActorRollout) + resource_pool = self.resource_pool_manager.get_resource_pool( + Role.ActorRollout + ) actor_rollout_cls = RayClassWithInitArgs( cls=self.role_worker_mapping[Role.ActorRollout], config=self.config.actor_rollout_ref, role="actor_rollout", ) - self.resource_pool_to_cls[resource_pool]["actor_rollout"] = actor_rollout_cls + self.resource_pool_to_cls[resource_pool][ + "actor_rollout" + ] = actor_rollout_cls else: raise NotImplementedError # create critic if self.use_critic: resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic) - critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=self.config.critic) + critic_cls = RayClassWithInitArgs( + cls=self.role_worker_mapping[Role.Critic], config=self.config.critic + ) self.resource_pool_to_cls[resource_pool]["critic"] = critic_cls # create reference policy if needed if self.use_reference_policy: resource_pool = 
self.resource_pool_manager.get_resource_pool(Role.RefPolicy) ref_policy_cls = RayClassWithInitArgs( - self.role_worker_mapping[Role.RefPolicy], config=self.config.actor_rollout_ref, role="ref" + self.role_worker_mapping[Role.RefPolicy], + config=self.config.actor_rollout_ref, + role="ref", ) self.resource_pool_to_cls[resource_pool]["ref"] = ref_policy_cls # create a reward model if reward_fn is None if self.use_rm: # we create a RM here - resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel) - rm_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model) + resource_pool = self.resource_pool_manager.get_resource_pool( + Role.RewardModel + ) + rm_cls = RayClassWithInitArgs( + self.role_worker_mapping[Role.RewardModel], + config=self.config.reward_model, + ) self.resource_pool_to_cls[resource_pool]["rm"] = rm_cls # initialize WorkerGroup @@ -805,8 +970,13 @@ def init_workers(self): all_wg = {} self.wg_dicts = [] wg_kwargs = {} # Setting up kwargs for RayWorkerGroup - if OmegaConf.select(self.config.trainer, "ray_wait_register_center_timeout") is not None: - wg_kwargs["ray_wait_register_center_timeout"] = self.config.trainer.ray_wait_register_center_timeout + if ( + OmegaConf.select(self.config.trainer, "ray_wait_register_center_timeout") + is not None + ): + wg_kwargs["ray_wait_register_center_timeout"] = ( + self.config.trainer.ray_wait_register_center_timeout + ) for resource_pool, class_dict in self.resource_pool_to_cls.items(): worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict) @@ -849,24 +1019,37 @@ def _save_checkpoint(self): actor_remote_path = ( None if self.config.trainer.default_hdfs_dir is None - else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "actor") + else os.path.join( + self.config.trainer.default_hdfs_dir, + f"global_step_{self.global_steps}", + "actor", + ) ) - remove_previous_ckpt_in_save = 
self.config.trainer.get("remove_previous_ckpt_in_save", False) + remove_previous_ckpt_in_save = self.config.trainer.get( + "remove_previous_ckpt_in_save", False + ) if remove_previous_ckpt_in_save: print( "Warning: remove_previous_ckpt_in_save is deprecated, set max_actor_ckpt_to_keep=1 and " "max_critic_ckpt_to_keep=1 instead" ) max_actor_ckpt_to_keep = ( - self.config.trainer.get("max_actor_ckpt_to_keep", None) if not remove_previous_ckpt_in_save else 1 + self.config.trainer.get("max_actor_ckpt_to_keep", None) + if not remove_previous_ckpt_in_save + else 1 ) max_critic_ckpt_to_keep = ( - self.config.trainer.get("max_critic_ckpt_to_keep", None) if not remove_previous_ckpt_in_save else 1 + self.config.trainer.get("max_critic_ckpt_to_keep", None) + if not remove_previous_ckpt_in_save + else 1 ) self.actor_rollout_wg.save_checkpoint( - actor_local_path, actor_remote_path, self.global_steps, max_ckpt_to_keep=max_actor_ckpt_to_keep + actor_local_path, + actor_remote_path, + self.global_steps, + max_ckpt_to_keep=max_actor_ckpt_to_keep, ) if self.use_critic: @@ -874,10 +1057,17 @@ def _save_checkpoint(self): critic_remote_path = ( None if self.config.trainer.default_hdfs_dir is None - else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "critic") + else os.path.join( + self.config.trainer.default_hdfs_dir, + f"global_step_{self.global_steps}", + "critic", + ) ) self.critic_wg.save_checkpoint( - critic_local_path, critic_remote_path, self.global_steps, max_ckpt_to_keep=max_critic_ckpt_to_keep + critic_local_path, + critic_remote_path, + self.global_steps, + max_ckpt_to_keep=max_critic_ckpt_to_keep, ) # save dataloader @@ -900,11 +1090,15 @@ def _load_checkpoint(self): if self.config.trainer.default_hdfs_dir is not None: raise NotImplementedError("load from hdfs is not implemented yet") else: - checkpoint_folder = self.config.trainer.default_local_dir # TODO: check path + checkpoint_folder = ( + self.config.trainer.default_local_dir + 
) # TODO: check path if not os.path.isabs(checkpoint_folder): working_dir = os.getcwd() checkpoint_folder = os.path.join(working_dir, checkpoint_folder) - global_step_folder = find_latest_ckpt_path(checkpoint_folder) # None if no latest + global_step_folder = find_latest_ckpt_path( + checkpoint_folder + ) # None if no latest # find global_step_folder if self.config.trainer.resume_mode == "auto": @@ -913,10 +1107,12 @@ def _load_checkpoint(self): return 0 else: if self.config.trainer.resume_mode == "resume_path": - assert isinstance(self.config.trainer.resume_from_path, str), "resume ckpt must be str type" - assert "global_step_" in self.config.trainer.resume_from_path, ( - "resume ckpt must specify the global_steps" - ) + assert isinstance( + self.config.trainer.resume_from_path, str + ), "resume ckpt must be str type" + assert ( + "global_step_" in self.config.trainer.resume_from_path + ), "resume ckpt must specify the global_steps" global_step_folder = self.config.trainer.resume_from_path if not os.path.isabs(global_step_folder): working_dir = os.getcwd() @@ -932,37 +1128,49 @@ def _load_checkpoint(self): critic_path = os.path.join(global_step_folder, "critic") # load actor self.actor_rollout_wg.load_checkpoint( - actor_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load + actor_path, + del_local_after_load=self.config.trainer.del_local_ckpt_after_load, ) # load critic if self.use_critic: self.critic_wg.load_checkpoint( - critic_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load + critic_path, + del_local_after_load=self.config.trainer.del_local_ckpt_after_load, ) # load dataloader, # TODO: from remote not implemented yet dataloader_local_path = os.path.join(global_step_folder, "data.pt") if os.path.exists(dataloader_local_path): - dataloader_state_dict = torch.load(dataloader_local_path, weights_only=False) + dataloader_state_dict = torch.load( + dataloader_local_path, weights_only=False + ) 
self.train_dataloader.load_state_dict(dataloader_state_dict) else: - print(f"Warning: No dataloader state found at {dataloader_local_path}, will start from scratch") + print( + f"Warning: No dataloader state found at {dataloader_local_path}, will start from scratch" + ) def _balance_batch(self, batch: DataProto, metrics, logging_prefix="global_seqlen"): """Reorder the data on single controller such that each dp rank gets similar total tokens""" attention_mask = batch.batch["attention_mask"] batch_size = attention_mask.shape[0] - global_seqlen_lst = batch.batch["attention_mask"].view(batch_size, -1).sum(-1).tolist() # (train_batch_size,) + global_seqlen_lst = ( + batch.batch["attention_mask"].view(batch_size, -1).sum(-1).tolist() + ) # (train_batch_size,) world_size = self.actor_rollout_wg.world_size global_partition_lst = get_seqlen_balanced_partitions( global_seqlen_lst, k_partitions=world_size, equal_size=True ) # reorder based on index. The data will be automatically equally partitioned by dispatch function - global_idx = torch.tensor([j for partition in global_partition_lst for j in partition]) + global_idx = torch.tensor( + [j for partition in global_partition_lst for j in partition] + ) batch.reorder(global_idx) global_balance_stats = log_seqlen_unbalance( - seqlen_list=global_seqlen_lst, partitions=global_partition_lst, prefix=logging_prefix + seqlen_list=global_seqlen_lst, + partitions=global_partition_lst, + prefix=logging_prefix, ) metrics.update(global_balance_stats) @@ -985,7 +1193,9 @@ def fit_dpo(self): # Renamed for clarity as standard PPO loop project_name=self.config.trainer.project_name, experiment_name=self.config.trainer.experiment_name, default_backend=self.config.trainer.logger, - config=OmegaConf.to_container(self.config, resolve=True, throw_on_missing=False), + config=OmegaConf.to_container( + self.config, resolve=True, throw_on_missing=False + ), ) except Exception as e: print(f"Warning: Failed to initialize logger: {e}") @@ -993,12 
+1203,16 @@ def fit_dpo(self): # Renamed for clarity as standard PPO loop self.global_steps = 0 # Load checkpoint before doing anything loaded_step = self._load_checkpoint() - self.global_steps = loaded_step + 1 if loaded_step is not None and loaded_step > 0 else 1 + self.global_steps = ( + loaded_step + 1 if loaded_step is not None and loaded_step > 0 else 1 + ) print( f"Starting Online DPO training from global step {self.global_steps}. " f"Total steps: {self.total_training_steps}" ) - print(f"Reference model update frequency: {self.config.trainer.get('ref_update_freq', 'Not Set')}") + print( + f"Reference model update frequency: {self.config.trainer.get('ref_update_freq', 'Not Set')}" + ) # Check if reference policy is configured correctly for this mode if not self.use_reference_policy: @@ -1011,7 +1225,9 @@ def fit_dpo(self): # Renamed for clarity as standard PPO loop # "and a configured reference worker.") # Perform validation before training - if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True): + if self.val_reward_fn is not None and self.config.trainer.get( + "val_before_train", True + ): print("Running validation before Online DPO training...") val_metrics = self._validate() pprint(f"Initial validation metrics: {val_metrics}") @@ -1053,7 +1269,9 @@ def fit_dpo(self): # Renamed for clarity as standard PPO loop metrics = {} timing_raw = {} step_timer = Timer(logger=None) - ref_log_prob_computed = False # Flag to track if ref log probs were computed + ref_log_prob_computed = ( + False # Flag to track if ref log probs were computed + ) try: # Outer try-except for the whole step step_timer.start() @@ -1072,64 +1290,95 @@ def fit_dpo(self): # Renamed for clarity as standard PPO loop and ref_update_freq > 0 and self.global_steps % ref_update_freq == 0 ): - print(f"\n[Step {self.global_steps}] Updating Reference Model Weights from Actor...") + print( + f"\n[Step {self.global_steps}] Updating Reference Model Weights from Actor..." 
+ ) try: # --- This requires careful implementation with FSDP --- # 1. Save actor state dict (potentially to CPU memory or disk) # This needs to be done collectively across actor worker ranks. # The checkpoint_manager might be adaptable, or use FSDP APIs directly. # Example placeholder using a conceptual save/load mechanism: - actor_state_path = "/tmp/actor_state_mid" # Temporary path - self.actor_rollout_wg.save_checkpoint(actor_state_path) # Adapt save logic + actor_state_path = ( + "/tmp/actor_state_mid" # Temporary path + ) + self.actor_rollout_wg.save_checkpoint( + actor_state_path + ) # Adapt save logic # 2. Load the state dict onto the reference model worker group # This also needs collective loading on the ref worker ranks. - self.ref_policy_wg.load_checkpoint(actor_state_path, None, True) # Adapt load logic + self.ref_policy_wg.load_checkpoint( + actor_state_path, None, True + ) # Adapt load logic - print(f"[Step {self.global_steps}] Reference Model Weights Updated.") + print( + f"[Step {self.global_steps}] Reference Model Weights Updated." 
+ ) # Optionally remove the temporary state file # os.remove(actor_state_path) # Needs rank-aware removal or shared storage except Exception as sync_e: - print(f"ERROR during reference model sync at step {self.global_steps}: {sync_e}") + print( + f"ERROR during reference model sync at step {self.global_steps}: {sync_e}" + ) traceback.print_exc() # Pop keys for generation pop_batch_keys = ["input_ids", "attention_mask"] if "position_ids" in batch.batch: pop_batch_keys.append("position_ids") - pop_non_tensor_keys = ["raw_prompt_ids"] if "raw_prompt_ids" in batch.non_tensor_batch else [] + pop_non_tensor_keys = ( + ["raw_prompt_ids"] + if "raw_prompt_ids" in batch.non_tensor_batch + else [] + ) if "multi_modal_inputs" in batch.non_tensor_batch.keys(): - pop_non_tensor_keys.extend(["multi_modal_data", "multi_modal_inputs"]) + pop_non_tensor_keys.extend( + ["multi_modal_data", "multi_modal_inputs"] + ) original_non_tensor_data = batch.non_tensor_batch gen_batch = batch.pop( batch_keys=pop_batch_keys, non_tensor_batch_keys=pop_non_tensor_keys, ) gen_batch = gen_batch.repeat( - repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True + repeat_times=self.config.actor_rollout_ref.rollout.n, + interleave=True, ) # (Add Debug prints for gen_batch if needed) # Generate sequences (chosen/rejected pairs) with _timer("gen", timing_raw): try: - gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) + gen_batch_output = ( + self.actor_rollout_wg.generate_sequences(gen_batch) + ) # (Add Debug prints for gen_batch_output if needed) except Exception as gen_e: - print(f"\n!!!!!!!! ERROR DURING GENERATION (Step {self.global_steps}) !!!!!!!!") + print( + f"\n!!!!!!!! ERROR DURING GENERATION (Step {self.global_steps}) !!!!!!!!" + ) print(gen_e) traceback.print_exc() - print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") + print( + "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!" 
+ ) step_timer.stop() continue # Combine original prompts with generated sequences - batch.non_tensor_batch = original_non_tensor_data # Restore non-tensor data + batch.non_tensor_batch = ( + original_non_tensor_data # Restore non-tensor data + ) batch.non_tensor_batch["uid"] = np.array( - [str(uuid.uuid4()) for _ in range(current_batch_size)], dtype=object + [str(uuid.uuid4()) for _ in range(current_batch_size)], + dtype=object, + ) + batch = batch.repeat( + repeat_times=self.config.actor_rollout_ref.rollout.n, + interleave=True, ) - batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) batch = batch.union(gen_batch_output) # (Add Debug prints after union if needed) @@ -1139,15 +1388,21 @@ def fit_dpo(self): # Renamed for clarity as standard PPO loop if self.config.trainer.balance_batch: self._balance_batch(batch, metrics=metrics) - batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist() + batch.meta_info["global_token_num"] = torch.sum( + batch.batch["attention_mask"], dim=-1 + ).tolist() # --- Compute Log Probs for the CURRENT policy (used for KL if enabled, or ActorAsRef # fallback) --- # Note: For pure DPO with external ref, this 'old_log_probs' might not be strictly needed # unless used for other metrics or a fallback. Keep it for now. 
with _timer("policy_log_prob", timing_raw): - policy_log_prob_output = self.actor_rollout_wg.compute_log_prob(batch) - batch = batch.union(policy_log_prob_output) # Adds 'old_log_probs' + policy_log_prob_output = ( + self.actor_rollout_wg.compute_log_prob(batch) + ) + batch = batch.union( + policy_log_prob_output + ) # Adds 'old_log_probs' # (Debug prints for old_log_probs) # --- Compute Log Probs using the EXTERNAL Reference Model --- @@ -1156,8 +1411,8 @@ def fit_dpo(self): # Renamed for clarity as standard PPO loop # print(f"---- [Step {self.global_steps}] DEBUG DPO: Calling compute_ref_log_prob ----") try: # 'batch' contains interleaved chosen/rejected sequences - ref_log_prob_output = self.ref_policy_wg.compute_ref_log_prob( - batch + ref_log_prob_output = ( + self.ref_policy_wg.compute_ref_log_prob(batch) ) # Returns DataProto with 'ref_log_prob' batch = batch.union( ref_log_prob_output @@ -1166,7 +1421,9 @@ def fit_dpo(self): # Renamed for clarity as standard PPO loop # print(f"---- [Step {self.global_steps}] DEBUG DPO: ref_log_prob tensor shape: " # f"{batch.batch['ref_log_prob'].shape} ----") except Exception as ref_e: - print(f"ERROR computing reference log probs at step {self.global_steps}: {ref_e}") + print( + f"ERROR computing reference log probs at step {self.global_steps}: {ref_e}" + ) traceback.print_exc() batch.batch["ref_log_prob"] = None # Mark as failed ref_log_prob_computed = False @@ -1183,7 +1440,9 @@ def fit_dpo(self): # Renamed for clarity as standard PPO loop # ... Ensure this calculates 'token_level_rewards' or similar ... if self.use_rm: reward_tensor_rm = self.rm_wg.compute_rm_score(batch) - batch = batch.union(reward_tensor_rm) # Adds 'rm_scores' + batch = batch.union( + reward_tensor_rm + ) # Adds 'rm_scores' reward_extra_infos_dict = {} try: @@ -1192,25 +1451,40 @@ def fit_dpo(self): # Renamed for clarity as standard PPO loop # f"Using dummy rewards. 
----") # Use rm_scores if available, otherwise zeros reward_tensor = batch.batch.get( - "rm_scores", torch.zeros_like(batch.batch["response_mask"], dtype=torch.float32) + "rm_scores", + torch.zeros_like( + batch.batch["response_mask"], + dtype=torch.float32, + ), ) else: - reward_result = self.reward_fn(batch, return_dict=True) - reward_tensor = reward_result["reward_tensor"] # Final combined reward - reward_extra_infos_dict = reward_result.get("reward_extra_info", {}) + reward_result = self.reward_fn( + batch, return_dict=True + ) + reward_tensor = reward_result[ + "reward_tensor" + ] # Final combined reward + reward_extra_infos_dict = reward_result.get( + "reward_extra_info", {} + ) except Exception: # print(f'---- [DEBUG Step {self.global_steps}] Error in reward_fn call: {e}. ' # f'Using dummy rewards. ----') traceback.print_exc() - reward_tensor = torch.zeros_like(batch.batch["response_mask"], dtype=torch.float32) + reward_tensor = torch.zeros_like( + batch.batch["response_mask"], dtype=torch.float32 + ) reward_extra_infos_dict = {} # Use 'token_level_rewards' as the key for preference calculation batch.batch["token_level_rewards"] = reward_tensor if reward_extra_infos_dict: batch.non_tensor_batch.update( - {k: np.array(v) for k, v in reward_extra_infos_dict.items()} + { + k: np.array(v) + for k, v in reward_extra_infos_dict.items() + } ) # --- Determine Preferences --- @@ -1221,40 +1495,70 @@ def fit_dpo(self): # Renamed for clarity as standard PPO loop dpo_update_batch_proto = None # Initialize with _timer("prepare_dpo_batch", timing_raw): try: - if "preferences" not in batch.batch or batch.batch["preferences"] is None: - raise ValueError("'preferences' key missing or None after compute_onlineDPO_pref.") + if ( + "preferences" not in batch.batch + or batch.batch["preferences"] is None + ): + raise ValueError( + "'preferences' key missing or None after compute_onlineDPO_pref." 
+ ) # Check if reference log probs were computed successfully (if needed) - if self.use_reference_policy and not ref_log_prob_computed: - raise ValueError("Reference log probs required but failed to compute.") + if ( + self.use_reference_policy + and not ref_log_prob_computed + ): + raise ValueError( + "Reference log probs required but failed to compute." + ) # Check required base keys - required_keys = ["input_ids", "attention_mask", "response_mask"] + required_keys = [ + "input_ids", + "attention_mask", + "response_mask", + ] for rk in required_keys: if rk not in batch.batch or batch.batch[rk] is None: - raise KeyError(f"Required key '{rk}' missing from batch for DPO prep.") + raise KeyError( + f"Required key '{rk}' missing from batch for DPO prep." + ) - preferences_mask = batch.batch["preferences"] # Shape [batch_size * n] + preferences_mask = batch.batch[ + "preferences" + ] # Shape [batch_size * n] not_preferences_mask = ~preferences_mask # Gather Chosen/Rejected Base Tensors - chosen_input_ids = batch.batch["input_ids"][preferences_mask] - chosen_attention_mask = batch.batch["attention_mask"][preferences_mask] - rejected_input_ids = batch.batch["input_ids"][not_preferences_mask] - rejected_attention_mask = batch.batch["attention_mask"][not_preferences_mask] + chosen_input_ids = batch.batch["input_ids"][ + preferences_mask + ] + chosen_attention_mask = batch.batch["attention_mask"][ + preferences_mask + ] + rejected_input_ids = batch.batch["input_ids"][ + not_preferences_mask + ] + rejected_attention_mask = batch.batch["attention_mask"][ + not_preferences_mask + ] chosen_position_ids = ( batch.batch.get("position_ids")[preferences_mask] if "position_ids" in batch.batch else None ) rejected_position_ids = ( - batch.batch.get("position_ids")[not_preferences_mask] + batch.batch.get("position_ids")[ + not_preferences_mask + ] if "position_ids" in batch.batch else None ) # Create Labels - print("WARNING: Creating DPO labels using configured max_prompt_length...") 
+ print( + "WARNING: Creating DPO labels using configured max_prompt_length..." + ) prompt_len = self.config.data.max_prompt_length chosen_labels = chosen_input_ids.clone() chosen_labels[:, :prompt_len] = -100 @@ -1263,15 +1567,23 @@ def fit_dpo(self): # Renamed for clarity as standard PPO loop # Calculate and Gather Reference Log Probs (Sequence Level) if self.use_reference_policy: - ref_log_prob_tensor = batch.batch["ref_log_prob"] # Token level [bsz * n, seq_len] + ref_log_prob_tensor = batch.batch[ + "ref_log_prob" + ] # Token level [bsz * n, seq_len] response_mask_full = batch.batch[ "response_mask" ] # Response mask [bsz * n, seq_len] - ref_sequence_logps = (ref_log_prob_tensor * response_mask_full).sum( + ref_sequence_logps = ( + ref_log_prob_tensor * response_mask_full + ).sum( dim=-1 ) # Sequence level [bsz * n] - reference_chosen_logps = ref_sequence_logps[preferences_mask] - reference_rejected_logps = ref_sequence_logps[not_preferences_mask] + reference_chosen_logps = ref_sequence_logps[ + preferences_mask + ] + reference_rejected_logps = ref_sequence_logps[ + not_preferences_mask + ] else: # If not using external ref, DPO needs ActorAsRef logic in dp_actor # We won't add the keys here, dp_actor will handle it (or fail if not modified) @@ -1293,88 +1605,135 @@ def fit_dpo(self): # Renamed for clarity as standard PPO loop } # Conditionally add reference logps if computed if reference_chosen_logps is not None: - dpo_tensors["reference_chosen_logps"] = reference_chosen_logps + dpo_tensors["reference_chosen_logps"] = ( + reference_chosen_logps + ) if reference_rejected_logps is not None: - dpo_tensors["reference_rejected_logps"] = reference_rejected_logps + dpo_tensors["reference_rejected_logps"] = ( + reference_rejected_logps + ) # Add position ids if they exist if chosen_position_ids is not None: - dpo_tensors["chosen_position_ids"] = chosen_position_ids + dpo_tensors["chosen_position_ids"] = ( + chosen_position_ids + ) if rejected_position_ids is not 
None: - dpo_tensors["rejected_position_ids"] = rejected_position_ids + dpo_tensors["rejected_position_ids"] = ( + rejected_position_ids + ) # Prepare Meta Info dpo_meta = { - "dpo_beta": OmegaConf.select(self.config.algorithm, "dpo_beta", default=0.1), + "dpo_beta": OmegaConf.select( + self.config.algorithm, "dpo_beta", default=0.1 + ), "dpo_loss_type": OmegaConf.select( - self.config.algorithm, "dpo_loss_type", default="sigmoid" + self.config.algorithm, + "dpo_loss_type", + default="sigmoid", ), "dpo_label_smoothing": OmegaConf.select( - self.config.algorithm, "dpo_label_smoothing", default=0.0 + self.config.algorithm, + "dpo_label_smoothing", + default=0.0, ), "use_reference_policy": self.use_reference_policy, "reference_free": not self.use_reference_policy, # False if using external ref "global_step": self.global_steps, } - dpo_update_batch_proto = DataProto.from_dict(tensors=dpo_tensors, meta_info=dpo_meta) + dpo_update_batch_proto = DataProto.from_dict( + tensors=dpo_tensors, meta_info=dpo_meta + ) # print(f"---- [Step {self.global_steps}] DEBUG DPO: Prepared DPO Update Batch ----") # print(f" Keys: {list(dpo_update_batch_proto.batch.keys())}") # print(f" Meta Info: {dpo_meta}") except Exception as e_prep: - print(f"ERROR preparing DPO batch at step {self.global_steps}: {e_prep}") + print( + f"ERROR preparing DPO batch at step {self.global_steps}: {e_prep}" + ) traceback.print_exc() dpo_update_batch_proto = None # Skip update on error # --- Actor Update Step --- actor_output = None - if self.config.trainer.critic_warmup <= self.global_steps and dpo_update_batch_proto: + if ( + self.config.trainer.critic_warmup <= self.global_steps + and dpo_update_batch_proto + ): with _timer("update_actor", timing_raw): # Pass the batch containing reference log probs (if computed) # The modified update_actor_dpo expects them if reference_free=False - actor_output = self.actor_rollout_wg.update_actor_dpo(dpo_update_batch_proto) + actor_output = 
self.actor_rollout_wg.update_actor_dpo( + dpo_update_batch_proto + ) if actor_output and "metrics" in actor_output.meta_info: - metrics.update(reduce_metrics(actor_output.meta_info["metrics"])) + metrics.update( + reduce_metrics(actor_output.meta_info["metrics"]) + ) elif dpo_update_batch_proto is None: print( f"Skipping actor update at step {self.global_steps} due to DPO batch preparation error." ) # --- Validation and Saving --- - test_freq = OmegaConf.select(self.config.trainer, "test_freq", default=-1) + test_freq = OmegaConf.select( + self.config.trainer, "test_freq", default=-1 + ) is_last_step = self.global_steps >= self.total_training_steps if ( self.val_reward_fn is not None and test_freq > 0 and (is_last_step or self.global_steps % test_freq == 0) ): - print(f"\nRunning DPO validation at step {self.global_steps}...") + print( + f"\nRunning DPO validation at step {self.global_steps}..." + ) val_timing_raw = {} with _timer("testing", val_timing_raw): val_metrics: dict = self._validate() if is_last_step: last_val_metrics = val_metrics if val_metrics: - metrics["time/validation_run"] = val_timing_raw.get("testing", 0) + metrics["time/validation_run"] = val_timing_raw.get( + "testing", 0 + ) metrics.update(val_metrics) else: print("Validation skipped or returned no metrics.") - save_freq = OmegaConf.select(self.config.trainer, "save_freq", default=-1) - if save_freq > 0 and (is_last_step or self.global_steps % save_freq == 0): - print(f"\nSaving DPO checkpoint at step {self.global_steps}...") + save_freq = OmegaConf.select( + self.config.trainer, "save_freq", default=-1 + ) + if save_freq > 0 and ( + is_last_step or self.global_steps % save_freq == 0 + ): + print( + f"\nSaving DPO checkpoint at step {self.global_steps}..." 
+ ) with _timer("save_checkpoint", timing_raw): self._save_checkpoint() # Saves actor (and potentially critic if used elsewhere) - metrics["time/save_checkpoint"] = timing_raw.get("save_checkpoint", 0) + metrics["time/save_checkpoint"] = timing_raw.get( + "save_checkpoint", 0 + ) # --- End main step timer context --- # --- Metrics calculation AFTER the 'step' timer block --- - metrics.update(compute_dpo_data_metrics(batch=batch)) # Use DPO-specific metrics - metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) + metrics.update( + compute_dpo_data_metrics(batch=batch) + ) # Use DPO-specific metrics + metrics.update( + compute_timing_metrics(batch=batch, timing_raw=timing_raw) + ) n_gpus = self.resource_pool_manager.get_n_gpus() if "step" in timing_raw: - metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus)) + metrics.update( + compute_throughout_metrics( + batch=batch, timing_raw=timing_raw, n_gpus=n_gpus + ) + ) else: print( f"Warning: 'step' key missing from timing_raw at step {self.global_steps}. 
" @@ -1385,14 +1744,18 @@ def fit_dpo(self): # Renamed for clarity as standard PPO loop metrics["time/step"] = step_timer.last # Log metrics - log_freq = OmegaConf.select(self.config.trainer, "log_freq", default=1) + log_freq = OmegaConf.select( + self.config.trainer, "log_freq", default=1 + ) if logger and self.global_steps % log_freq == 0: log_payload = metrics.copy() # Add learning rate to log payload if actor_output and "actor/lr" in metrics: log_payload["actor/lr"] = metrics["actor/lr"] - print(f"[Step {self.global_steps} DPO] Logging Step Payload Keys: {list(log_payload.keys())}") + print( + f"[Step {self.global_steps} DPO] Logging Step Payload Keys: {list(log_payload.keys())}" + ) try: logger.log(data=log_payload, step=self.global_steps) except Exception as e: @@ -1407,10 +1770,14 @@ def fit_dpo(self): # Renamed for clarity as standard PPO loop progress_bar.set_postfix(postfix_metrics) except Exception as step_e: - print(f"\n!!!!!!!! ERROR DURING DPO Step {self.global_steps} !!!!!!!!") + print( + f"\n!!!!!!!! ERROR DURING DPO Step {self.global_steps} !!!!!!!!" + ) print(f"Caught Exception: {step_e}") traceback.print_exc() - print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") + print( + "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!" 
+ ) step_timer.stop() should_stop = True break @@ -1437,12 +1804,18 @@ def fit_dpo(self): # Renamed for clarity as standard PPO loop print(f"Online DPO Training finished at step {final_step}.") # Save final checkpoint save_freq = OmegaConf.select(self.config.trainer, "save_freq", default=-1) - if not self.config.trainer.get("val_only", False) and (save_freq <= 0 or final_step % save_freq != 0): + if not self.config.trainer.get("val_only", False) and ( + save_freq <= 0 or final_step % save_freq != 0 + ): print(f"Saving final DPO checkpoint at step {final_step}...") self._save_checkpoint() # Final validation run - if self.val_reward_fn and last_val_metrics is None and not self.config.trainer.get("val_only", False): + if ( + self.val_reward_fn + and last_val_metrics is None + and not self.config.trainer.get("val_only", False) + ): print("Running final validation...") last_val_metrics = self._validate() if last_val_metrics and logger: diff --git a/Agent0/executor_train/verl/recipe/sppo/dp_actor.py b/Agent0/executor_train/verl/recipe/sppo/dp_actor.py index df14c0b..a6a6091 100644 --- a/Agent0/executor_train/verl/recipe/sppo/dp_actor.py +++ b/Agent0/executor_train/verl/recipe/sppo/dp_actor.py @@ -63,10 +63,19 @@ def update_policy(self, data: DataProto): # make sure we are in training mode self.actor_module.train() - temperature = data.meta_info["temperature"] # temperature must be in the data.meta_info to avoid slient error + temperature = data.meta_info[ + "temperature" + ] # temperature must be in the data.meta_info to avoid slient error multi_turn = data.meta_info.get("multi_turn", False) - select_keys = ["responses", "input_ids", "attention_mask", "position_ids", "old_log_probs", "seq_level_rewards"] + select_keys = [ + "responses", + "input_ids", + "attention_mask", + "position_ids", + "old_log_probs", + "seq_level_rewards", + ] if multi_turn: select_keys.append("loss_mask") if self.config.use_kl_loss: @@ -77,9 +86,13 @@ def update_policy(self, data: DataProto): # 
Split to make minibatch iterator for updating the actor # See PPO paper for details. https://arxiv.org/abs/1707.06347 if has_multi_modal_inputs: - num_mini_batches = data.batch.batch_size[0] // self.config.ppo_mini_batch_size + num_mini_batches = ( + data.batch.batch_size[0] // self.config.ppo_mini_batch_size + ) non_tensor_select_keys = ["multi_modal_inputs"] - dataloader = data.select(select_keys, non_tensor_select_keys).chunk(num_mini_batches) + dataloader = data.select(select_keys, non_tensor_select_keys).chunk( + num_mini_batches + ) else: dataloader = batch.split(self.config.ppo_mini_batch_size) @@ -90,28 +103,47 @@ def update_policy(self, data: DataProto): mini_batch = data if has_multi_modal_inputs: self.gradient_accumulation = ( - self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu + self.config.ppo_mini_batch_size + // self.config.ppo_micro_batch_size_per_gpu ) - num_micro_batches = mini_batch.batch.batch_size[0] // self.config.ppo_micro_batch_size_per_gpu - micro_batches = data.select(select_keys, non_tensor_select_keys).chunk(num_micro_batches) + num_micro_batches = ( + mini_batch.batch.batch_size[0] + // self.config.ppo_micro_batch_size_per_gpu + ) + micro_batches = data.select( + select_keys, non_tensor_select_keys + ).chunk(num_micro_batches) elif self.config.use_dynamic_bsz: - max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size - micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len) + max_token_len = ( + self.config.ppo_max_token_len_per_gpu + * self.ulysses_sequence_parallel_size + ) + micro_batches, _ = rearrange_micro_batches( + batch=mini_batch, max_token_len=max_token_len + ) else: self.gradient_accumulation = ( - self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu + self.config.ppo_mini_batch_size + // self.config.ppo_micro_batch_size_per_gpu ) # split batch into micro_batches - micro_batches = 
mini_batch.split(self.config.ppo_micro_batch_size_per_gpu) + micro_batches = mini_batch.split( + self.config.ppo_micro_batch_size_per_gpu + ) self.actor_optimizer.zero_grad() for data in micro_batches: # Support all hardwares if isinstance(data, DataProto): - data = {**data.batch.to(get_device_id()), **data.non_tensor_batch} + data = { + **data.batch.to(get_device_id()), + **data.non_tensor_batch, + } else: - data = data.to(get_device_id()) # actor device is cpu when using offload + data = data.to( + get_device_id() + ) # actor device is cpu when using offload responses = data["responses"] response_length = responses.size(1) attention_mask = data["attention_mask"] @@ -132,7 +164,9 @@ def update_policy(self, data: DataProto): if entropy_coeff != 0: calculate_entropy = True entropy, log_prob = self._forward_micro_batch( - micro_batch=data, temperature=temperature, calculate_entropy=calculate_entropy + micro_batch=data, + temperature=temperature, + calculate_entropy=calculate_entropy, ) pg_loss, log_ratios, preference = compute_sppo_loss( @@ -145,7 +179,11 @@ def update_policy(self, data: DataProto): ) if entropy_coeff != 0: - entropy_loss = agg_loss(loss_mat=entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) + entropy_loss = agg_loss( + loss_mat=entropy, + loss_mask=response_mask, + loss_agg_mode=loss_agg_mode, + ) # compute policy loss policy_loss = pg_loss - entropy_loss * entropy_coeff @@ -156,10 +194,14 @@ def update_policy(self, data: DataProto): ref_log_prob = data["ref_log_prob"] # compute kl loss kld = kl_penalty( - logprob=log_prob, ref_logprob=ref_log_prob, kl_penalty=self.config.kl_loss_type + logprob=log_prob, + ref_logprob=ref_log_prob, + kl_penalty=self.config.kl_loss_type, ) kl_loss = agg_loss( - loss_mat=kld, loss_mask=response_mask, loss_agg_mode=self.config.loss_agg_mode + loss_mat=kld, + loss_mask=response_mask, + loss_agg_mode=self.config.loss_agg_mode, ) policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef @@ -168,7 +210,9 
@@ def update_policy(self, data: DataProto): if self.config.use_dynamic_bsz: # relative to the dynamic bsz - loss = policy_loss * (len(data) / self.config.ppo_mini_batch_size) + loss = policy_loss * ( + len(data) / self.config.ppo_mini_batch_size + ) else: loss = policy_loss / self.gradient_accumulation loss.backward() diff --git a/Agent0/executor_train/verl/recipe/sppo/main_sppo.py b/Agent0/executor_train/verl/recipe/sppo/main_sppo.py index d99f4f2..e478ad7 100644 --- a/Agent0/executor_train/verl/recipe/sppo/main_sppo.py +++ b/Agent0/executor_train/verl/recipe/sppo/main_sppo.py @@ -35,12 +35,18 @@ def main(config): def run_ppo(config) -> None: # TODO(linjunrong.ocss884): this ENV is left for resolving SGLang conflict with ray devices # isolation, will solve in the future - os.environ["ENSURE_CUDA_VISIBLE_DEVICES"] = os.environ.get("CUDA_VISIBLE_DEVICES", "") + os.environ["ENSURE_CUDA_VISIBLE_DEVICES"] = os.environ.get( + "CUDA_VISIBLE_DEVICES", "" + ) if not ray.is_initialized(): # this is for local ray cluster ray.init( runtime_env={ - "env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN", "VLLM_LOGGING_LEVEL": "WARN"} + "env_vars": { + "TOKENIZERS_PARALLELISM": "true", + "NCCL_DEBUG": "WARN", + "VLLM_LOGGING_LEVEL": "WARN", + } }, num_cpus=config.ray_init.num_cpus, ) @@ -59,7 +65,9 @@ def run(self, config): from verl.utils.fs import copy_to_local - pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values + pprint( + OmegaConf.to_container(config, resolve=True) + ) # resolve=True will eval symbol values OmegaConf.resolve(config) # download the checkpoint from hdfs @@ -70,7 +78,9 @@ def run(self, config): trust_remote_code = config.data.get("trust_remote_code", False) tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code) - processor = hf_processor(local_path, use_fast=True) # used for multimodal LLM, could be none + processor = hf_processor( + local_path, use_fast=True + ) # used for multimodal 
LLM, could be none # define worker classes if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}: @@ -125,15 +135,23 @@ def run(self, config): mapping[Role.RewardModel] = global_pool_id # use reference model - if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss: + if ( + config.algorithm.use_kl_in_reward + or config.actor_rollout_ref.actor.use_kl_loss + ): role_worker_mapping[Role.RefPolicy] = ray.remote(SPPOActorRolloutRefWorker) mapping[Role.RefPolicy] = global_pool_id reward_fn = load_reward_manager( - config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {}) + config, + tokenizer, + num_examine=0, + **config.reward_model.get("reward_kwargs", {}) ) val_reward_fn = load_reward_manager(config, tokenizer, num_examine=1) - resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) + resource_pool_manager = ResourcePoolManager( + resource_pool_spec=resource_pool_spec, mapping=mapping + ) trainer = RaySPPOTrainer( config=config, diff --git a/Agent0/executor_train/verl/recipe/sppo/sppo_ray_trainer.py b/Agent0/executor_train/verl/recipe/sppo/sppo_ray_trainer.py index 15e2f9c..7da13c0 100644 --- a/Agent0/executor_train/verl/recipe/sppo/sppo_ray_trainer.py +++ b/Agent0/executor_train/verl/recipe/sppo/sppo_ray_trainer.py @@ -48,7 +48,9 @@ from verl.utils.tracking import ValidationGenerationsLogger -def softmean(x: torch.Tensor, beta: float, dim: int = -1, keepdim: bool = False) -> torch.Tensor: +def softmean( + x: torch.Tensor, beta: float, dim: int = -1, keepdim: bool = False +) -> torch.Tensor: """ Compute SoftMean_ฮฒ(x) = (1/ฮฒ) * log( (1/n) * ฮฃ exp(ฮฒ * x_i) ) Falls back to arithmetic mean when ฮฒ=0. 
@@ -107,7 +109,9 @@ def __init__( assert self.hybrid_engine, "Currently, only support hybrid engine" if self.hybrid_engine: - assert Role.ActorRollout in role_worker_mapping, f"{role_worker_mapping.keys()=}" + assert ( + Role.ActorRollout in role_worker_mapping + ), f"{role_worker_mapping.keys()=}" self.role_worker_mapping = role_worker_mapping self.resource_pool_manager = resource_pool_manager @@ -120,7 +124,9 @@ def __init__( # define in-reward KL control # kl loss control currently not supported if config.algorithm.use_kl_in_reward: - self.kl_ctrl_in_reward = core_algos.get_kl_controller(config.algorithm.kl_ctrl) + self.kl_ctrl_in_reward = core_algos.get_kl_controller( + config.algorithm.kl_ctrl + ) self.use_critic = False @@ -152,7 +158,9 @@ def fit(self): # perform validation before training # currently, we only support validation using the reward_function. - if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True): + if self.val_reward_fn is not None and self.config.trainer.get( + "val_before_train", True + ): val_metrics = self._validate() pprint(f"Initial validation metrics: {val_metrics}") logger.log(data=val_metrics, step=self.global_steps) @@ -160,7 +168,11 @@ def fit(self): return # add tqdm - progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress") + progress_bar = tqdm( + total=self.total_training_steps, + initial=self.global_steps, + desc="Training Progress", + ) # we start from step 1 self.global_steps += 1 @@ -185,7 +197,10 @@ def fit(self): batch_keys=batch_keys_to_pop, non_tensor_batch_keys=non_tensor_batch_keys_to_pop, ) - gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) + gen_batch = gen_batch.repeat( + repeat_times=self.config.actor_rollout_ref.rollout.n, + interleave=True, + ) is_last_step = self.global_steps >= self.total_training_steps @@ -193,9 +208,13 @@ def fit(self): # generate a batch with 
simple_timer("gen", timing_raw): if not self.async_rollout_mode: - gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) + gen_batch_output = self.actor_rollout_wg.generate_sequences( + gen_batch + ) else: - gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch) + gen_batch_output = ( + self.async_rollout_manager.generate_sequences(gen_batch) + ) timing_raw.update(gen_batch_output.meta_info["timing"]) gen_batch_output.meta_info.pop("timing", None) @@ -203,7 +222,11 @@ def fit(self): with simple_timer("gen_max", timing_raw): gen_baseline_batch = deepcopy(gen_batch) gen_baseline_batch.meta_info["do_sample"] = False - gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch) + gen_baseline_output = ( + self.actor_rollout_wg.generate_sequences( + gen_baseline_batch + ) + ) batch = batch.union(gen_baseline_output) reward_baseline_tensor = self.reward_fn(batch) @@ -216,10 +239,14 @@ def fit(self): del gen_baseline_batch, gen_baseline_output batch.non_tensor_batch["uid"] = np.array( - [str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object + [str(uuid.uuid4()) for _ in range(len(batch.batch))], + dtype=object, ) # repeat to align with repeated responses in rollout - batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) + batch = batch.repeat( + repeat_times=self.config.actor_rollout_ref.rollout.n, + interleave=True, + ) batch = batch.union(gen_batch_output) batch.batch["response_mask"] = compute_response_mask(batch) @@ -232,7 +259,9 @@ def fit(self): self._balance_batch(batch, metrics=metrics) # compute global_valid tokens - batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist() + batch.meta_info["global_token_num"] = torch.sum( + batch.batch["attention_mask"], dim=-1 + ).tolist() with simple_timer("reward", timing_raw): # compute reward model score @@ -241,9 +270,13 @@ def fit(self): batch = batch.union(reward_tensor) if 
self.config.reward_model.launch_reward_fn_async: - future_reward = compute_reward_async.remote(batch, self.config, self.tokenizer) + future_reward = compute_reward_async.remote( + batch, self.config, self.tokenizer + ) else: - reward_tensor, reward_extra_infos_dict = compute_reward(batch, self.reward_fn) + reward_tensor, reward_extra_infos_dict = compute_reward( + batch, self.reward_fn + ) # recompute old_log_probs with simple_timer("old_log_prob", timing_raw): @@ -251,8 +284,14 @@ def fit(self): entropys = old_log_prob.batch["entropys"] response_masks = batch.batch["response_mask"] loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode - entropy_agg = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode) - old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()} + entropy_agg = agg_loss( + loss_mat=entropys, + loss_mask=response_masks, + loss_agg_mode=loss_agg_mode, + ) + old_log_prob_metrics = { + "actor/entropy": entropy_agg.detach().item() + } metrics.update(old_log_prob_metrics) old_log_prob.batch.pop("entropys") batch = batch.union(old_log_prob) @@ -277,17 +316,25 @@ def fit(self): batch.batch["token_level_scores"] = reward_tensor if reward_extra_infos_dict: - batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()}) + batch.non_tensor_batch.update( + {k: np.array(v) for k, v in reward_extra_infos_dict.items()} + ) # compute rewards. 
apply_kl_penalty if available if self.config.algorithm.use_kl_in_reward: batch, kl_metrics = apply_kl_penalty( - batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty + batch, + kl_ctrl=self.kl_ctrl_in_reward, + kl_penalty=self.config.algorithm.kl_penalty, ) metrics.update(kl_metrics) else: - batch.batch["token_level_rewards"] = batch.batch["token_level_scores"] - batch.batch["seq_level_rewards"] = batch.batch["token_level_scores"] + batch.batch["token_level_rewards"] = batch.batch[ + "token_level_scores" + ] + batch.batch["seq_level_rewards"] = batch.batch[ + "token_level_scores" + ] beta = self.config.algorithm.sppo_eta batch = compute_advantage(batch, beta=beta) @@ -296,16 +343,22 @@ def fit(self): if self.use_critic: with simple_timer("update_critic", timing_raw): critic_output = self.critic_wg.update_critic(batch) - critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"]) + critic_output_metrics = reduce_metrics( + critic_output.meta_info["metrics"] + ) metrics.update(critic_output_metrics) # implement critic warmup if self.config.trainer.critic_warmup <= self.global_steps: # update actor with simple_timer("update_actor", timing_raw): - batch.meta_info["multi_turn"] = self.config.actor_rollout_ref.rollout.multi_turn.enable + batch.meta_info["multi_turn"] = ( + self.config.actor_rollout_ref.rollout.multi_turn.enable + ) actor_output = self.actor_rollout_wg.update_actor(batch) - actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) + actor_output_metrics = reduce_metrics( + actor_output.meta_info["metrics"] + ) metrics.update(actor_output_metrics) # Log rollout generations if enabled @@ -313,9 +366,15 @@ def fit(self): if rollout_data_dir: with simple_timer("dump_rollout_generations", timing_raw): print(batch.batch.keys()) - inputs = self.tokenizer.batch_decode(batch.batch["prompts"], skip_special_tokens=True) - outputs = self.tokenizer.batch_decode(batch.batch["responses"], 
skip_special_tokens=True) - scores = batch.batch["token_level_scores"].sum(-1).cpu().tolist() + inputs = self.tokenizer.batch_decode( + batch.batch["prompts"], skip_special_tokens=True + ) + outputs = self.tokenizer.batch_decode( + batch.batch["responses"], skip_special_tokens=True + ) + scores = ( + batch.batch["token_level_scores"].sum(-1).cpu().tolist() + ) self._dump_generations( inputs=inputs, outputs=outputs, @@ -328,7 +387,10 @@ def fit(self): if ( self.val_reward_fn is not None and self.config.trainer.test_freq > 0 - and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0) + and ( + is_last_step + or self.global_steps % self.config.trainer.test_freq == 0 + ) ): with simple_timer("testing", timing_raw): val_metrics: dict = self._validate() @@ -337,7 +399,8 @@ def fit(self): metrics.update(val_metrics) if self.config.trainer.save_freq > 0 and ( - is_last_step or self.global_steps % self.config.trainer.save_freq == 0 + is_last_step + or self.global_steps % self.config.trainer.save_freq == 0 ): with simple_timer("save_checkpoint", timing_raw): self._save_checkpoint() diff --git a/Agent0/executor_train/verl/recipe/sppo/sppo_worker.py b/Agent0/executor_train/verl/recipe/sppo/sppo_worker.py index fbe3a6e..dde1b9a 100644 --- a/Agent0/executor_train/verl/recipe/sppo/sppo_worker.py +++ b/Agent0/executor_train/verl/recipe/sppo/sppo_worker.py @@ -45,7 +45,9 @@ def init_model(self): from omegaconf import OmegaConf - override_model_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create())) + override_model_config = OmegaConf.to_container( + self.config.model.get("override_config", OmegaConf.create()) + ) use_remove_padding = self.config.model.get("use_remove_padding", False) use_fused_kernels = self.config.model.get("use_fused_kernels", False) @@ -58,19 +60,24 @@ def init_model(self): else: optim_config = None fsdp_config = OmegaConf.create() - self.actor_module_fsdp, self.actor_optimizer, self.actor_lr_scheduler, 
self.actor_model_config = ( - self._build_model_optimizer( - model_path=self.config.model.path, - fsdp_config=fsdp_config, - optim_config=optim_config, - override_model_config=override_model_config, - use_remove_padding=use_remove_padding, - use_fused_kernels=use_fused_kernels, - enable_gradient_checkpointing=self.config.model.get("enable_gradient_checkpointing", False), - trust_remote_code=self.config.model.get("trust_remote_code", False), - use_liger=self.config.model.get("use_liger", False), - role="actor", - ) + ( + self.actor_module_fsdp, + self.actor_optimizer, + self.actor_lr_scheduler, + self.actor_model_config, + ) = self._build_model_optimizer( + model_path=self.config.model.path, + fsdp_config=fsdp_config, + optim_config=optim_config, + override_model_config=override_model_config, + use_remove_padding=use_remove_padding, + use_fused_kernels=use_fused_kernels, + enable_gradient_checkpointing=self.config.model.get( + "enable_gradient_checkpointing", False + ), + trust_remote_code=self.config.model.get("trust_remote_code", False), + use_liger=self.config.model.get("use_liger", False), + role="actor", ) # get the original unwrapped module @@ -78,11 +85,15 @@ def init_model(self): if self._is_offload_param: offload_fsdp_model_to_cpu(self.actor_module_fsdp) - log_gpu_memory_usage("After offload actor model during init", logger=logger) + log_gpu_memory_usage( + "After offload actor model during init", logger=logger + ) if self._is_offload_optimizer: offload_fsdp_optimizer(optimizer=self.actor_optimizer) - log_gpu_memory_usage("After offload actor optimizer during init", logger=logger) + log_gpu_memory_usage( + "After offload actor optimizer during init", logger=logger + ) # load from checkpoint if self._is_actor: OmegaConf.set_struct(self.config.actor, True) @@ -90,7 +101,9 @@ def init_model(self): self.config.actor.use_remove_padding = use_remove_padding self.config.actor.use_fused_kernels = use_fused_kernels self.actor = DataParallelSPPOActor( - 
config=self.config.actor, actor_module=self.actor_module_fsdp, actor_optimizer=self.actor_optimizer + config=self.config.actor, + actor_module=self.actor_module_fsdp, + actor_optimizer=self.actor_optimizer, ) if self._is_rollout: @@ -114,7 +127,9 @@ def init_model(self): with open_dict(self.config.ref): self.config.ref.use_remove_padding = use_remove_padding self.config.ref.use_fused_kernels = use_fused_kernels - self.ref_policy = DataParallelSPPOActor(config=self.config.ref, actor_module=self.ref_module_fsdp) + self.ref_policy = DataParallelSPPOActor( + config=self.config.ref, actor_module=self.ref_module_fsdp + ) if self._is_actor: self.flops_counter = FlopsCounter(self.actor_model_config) @@ -122,6 +137,8 @@ def init_model(self): model=self.actor_module_fsdp, optimizer=self.actor.actor_optimizer, lr_scheduler=self.actor_lr_scheduler, - processing_class=self.processor if self.processor is not None else self.tokenizer, + processing_class=( + self.processor if self.processor is not None else self.tokenizer + ), checkpoint_config=self.config.actor.checkpoint, ) diff --git a/Agent0/executor_train/verl/scripts/converter_hf_to_mcore.py b/Agent0/executor_train/verl/scripts/converter_hf_to_mcore.py index b3101a6..ccb5d0b 100644 --- a/Agent0/executor_train/verl/scripts/converter_hf_to_mcore.py +++ b/Agent0/executor_train/verl/scripts/converter_hf_to_mcore.py @@ -35,11 +35,29 @@ def _init_args(): parser = argparse.ArgumentParser() - parser.add_argument("--hf_model_path", type=str, required=True, help="The path for the huggingface model") - parser.add_argument("--output_path", type=str, required=True, help="The path for the output mcore model") - parser.add_argument("--use_cpu_initialization", action="store_true", help="Whether to use cpu initialization") - parser.add_argument("--test", action="store_true", help="Whether to test the conversion") - parser.add_argument("--trust_remote_code", action="store_true", help="Whether to trust remote code") + parser.add_argument( + 
"--hf_model_path", + type=str, + required=True, + help="The path for the huggingface model", + ) + parser.add_argument( + "--output_path", + type=str, + required=True, + help="The path for the output mcore model", + ) + parser.add_argument( + "--use_cpu_initialization", + action="store_true", + help="Whether to use cpu initialization", + ) + parser.add_argument( + "--test", action="store_true", help="Whether to test the conversion" + ) + parser.add_argument( + "--trust_remote_code", action="store_true", help="Whether to trust remote code" + ) args = parser.parse_args() return args @@ -54,7 +72,9 @@ def test_conversion(megatron_model_provider, tfconfig, output_path, model): transformer_config=tfconfig, ) ref_state_dict = model_test[0].module.sharded_state_dict() - dist_checkpointing.load(ref_state_dict, output_path, strict=StrictHandling.ASSUME_OK_UNEXPECTED) + dist_checkpointing.load( + ref_state_dict, output_path, strict=StrictHandling.ASSUME_OK_UNEXPECTED + ) dut_state_dict = model[0].module.state_dict() for name in dut_state_dict.keys(): @@ -68,7 +88,9 @@ def test_conversion(megatron_model_provider, tfconfig, output_path, model): ref_data = ref_data.data.view(ref_data.local_shape) else: ref_data = ref_data.data - assert dut_data.shape == ref_data.shape, f"{name=} {dut_data.shape=} {ref_data.shape=}" + assert ( + dut_data.shape == ref_data.shape + ), f"{name=} {dut_data.shape=} {ref_data.shape=}" assert (dut_data == ref_data).all(), f"{name} is not equal" print(f"{name} is equal") else: @@ -84,7 +106,9 @@ def test_conversion(megatron_model_provider, tfconfig, output_path, model): ref_data = ref_data.data if name in dut_state_dict: dut_data = dut_state_dict[name].data - assert dut_data.shape == ref_data.shape, f"{name=} {dut_data.shape=} {ref_data.shape=}" + assert ( + dut_data.shape == ref_data.shape + ), f"{name=} {dut_data.shape=} {ref_data.shape=}" assert (dut_data == ref_data).all(), f"{name} is not equal" print(f"{name} is equal") else: @@ -99,18 +123,32 @@ 
def convert_checkpoint_from_transformers_to_megatron(hf_model, model, hf_config) head_dim = getattr(hf_config, "head_dim", hidden_dim // num_attention_heads) if num_attention_heads != num_key_value_heads: print("[WARNING] Converting GQA model") - has_qkv_bias = getattr(hf_config, "qkv_bias", False) or getattr(hf_config, "attention_bias", False) + has_qkv_bias = getattr(hf_config, "qkv_bias", False) or getattr( + hf_config, "attention_bias", False + ) has_share_expert = getattr(hf_config, "shared_expert_intermediate_size", None) with torch.no_grad(): model.embedding.word_embeddings.weight.copy_(hf_model.model.embed_tokens.weight) - for layer, hf_layer in zip(model.decoder.layers, hf_model.model.layers, strict=True): - layer.self_attention.linear_qkv.layer_norm_weight.copy_(hf_layer.input_layernorm.weight) + for layer, hf_layer in zip( + model.decoder.layers, hf_model.model.layers, strict=True + ): + layer.self_attention.linear_qkv.layer_norm_weight.copy_( + hf_layer.input_layernorm.weight + ) q = hf_layer.self_attn.q_proj.weight.view( - [num_key_value_heads, head_dim * num_attention_heads // num_key_value_heads, -1] + [ + num_key_value_heads, + head_dim * num_attention_heads // num_key_value_heads, + -1, + ] + ) + k = hf_layer.self_attn.k_proj.weight.view( + [num_key_value_heads, head_dim, -1] + ) + v = hf_layer.self_attn.v_proj.weight.view( + [num_key_value_heads, head_dim, -1] ) - k = hf_layer.self_attn.k_proj.weight.view([num_key_value_heads, head_dim, -1]) - v = hf_layer.self_attn.v_proj.weight.view([num_key_value_heads, head_dim, -1]) qkv = torch.cat([q, k, v], dim=1).view(-1, hidden_dim).contiguous() layer.self_attention.linear_qkv.weight.copy_(qkv) @@ -118,30 +156,53 @@ def convert_checkpoint_from_transformers_to_megatron(hf_model, model, hf_config) q_bias = hf_layer.self_attn.q_proj.bias.view([num_key_value_heads, -1]) k_bias = hf_layer.self_attn.k_proj.bias.view([num_key_value_heads, -1]) v_bias = hf_layer.self_attn.v_proj.bias.view([num_key_value_heads, 
-1]) - qkv_bias = torch.cat([q_bias, k_bias, v_bias], dim=1).view(-1).contiguous() + qkv_bias = ( + torch.cat([q_bias, k_bias, v_bias], dim=1).view(-1).contiguous() + ) layer.self_attention.linear_qkv.bias.copy_(qkv_bias) if hasattr(hf_layer.self_attn, "q_norm"): - layer.self_attention.q_layernorm.weight.copy_(hf_layer.self_attn.q_norm.weight.data) - layer.self_attention.k_layernorm.weight.copy_(hf_layer.self_attn.k_norm.weight.data) + layer.self_attention.q_layernorm.weight.copy_( + hf_layer.self_attn.q_norm.weight.data + ) + layer.self_attention.k_layernorm.weight.copy_( + hf_layer.self_attn.k_norm.weight.data + ) - layer.self_attention.linear_proj.weight.copy_(hf_layer.self_attn.o_proj.weight) - layer.pre_mlp_layernorm.weight.copy_(hf_layer.post_attention_layernorm.weight) + layer.self_attention.linear_proj.weight.copy_( + hf_layer.self_attn.o_proj.weight + ) + layer.pre_mlp_layernorm.weight.copy_( + hf_layer.post_attention_layernorm.weight + ) layer.mlp.router.weight.copy_(hf_layer.mlp.gate.weight) for idx, hf_expert in enumerate(hf_layer.mlp.experts): - fc1_weight = torch.cat([hf_expert.gate_proj.weight, hf_expert.up_proj.weight]) - layer.mlp.experts.linear_fc1._parameters[f"weight{idx}"].copy_(fc1_weight) - layer.mlp.experts.linear_fc2._parameters[f"weight{idx}"].copy_(hf_expert.down_proj.weight) + fc1_weight = torch.cat( + [hf_expert.gate_proj.weight, hf_expert.up_proj.weight] + ) + layer.mlp.experts.linear_fc1._parameters[f"weight{idx}"].copy_( + fc1_weight + ) + layer.mlp.experts.linear_fc2._parameters[f"weight{idx}"].copy_( + hf_expert.down_proj.weight + ) if has_share_expert: - layer.mlp.shared_experts.gate_weight.copy_(hf_layer.mlp.shared_expert_gate.weight) + layer.mlp.shared_experts.gate_weight.copy_( + hf_layer.mlp.shared_expert_gate.weight + ) shared_fc1_weight = torch.cat( - [hf_layer.mlp.shared_expert.gate_proj.weight, hf_layer.mlp.shared_expert.up_proj.weight] + [ + hf_layer.mlp.shared_expert.gate_proj.weight, + 
hf_layer.mlp.shared_expert.up_proj.weight, + ] ) layer.mlp.shared_experts.linear_fc1.weight.copy_(shared_fc1_weight) - layer.mlp.shared_experts.linear_fc2.weight.copy_(hf_layer.mlp.shared_expert.down_proj.weight) + layer.mlp.shared_experts.linear_fc2.weight.copy_( + hf_layer.mlp.shared_expert.down_proj.weight + ) model.decoder.final_layernorm.weight.copy_(hf_model.model.norm.weight) model.output_layer.weight.copy_(hf_model.lm_head.weight) @@ -154,14 +215,18 @@ def safe_copy( ): if not skip_dtype_assert: if src_tensor.dtype != dst_tensor.dtype: - raise ValueError(f"Get source dtype {src_tensor.dtype}, but target dtype {dst_tensor.dtype}") + raise ValueError( + f"Get source dtype {src_tensor.dtype}, but target dtype {dst_tensor.dtype}" + ) assert src_tensor.shape == dst_tensor.shape dst_tensor.data.copy_(src_tensor.data) return src_tensor.numel() @torch.inference_mode() -def convert_checkpoint_from_transformers_to_megatron_qwen2_5_vl(hfmodel, mgmodel, hf_config): +def convert_checkpoint_from_transformers_to_megatron_qwen2_5_vl( + hfmodel, mgmodel, hf_config +): mgmodel = mgmodel.bfloat16() hfmodel = hfmodel.bfloat16() num_attention_heads = hf_config.num_attention_heads @@ -177,21 +242,31 @@ def convert_checkpoint_from_transformers_to_megatron_qwen2_5_vl(hfmodel, mgmodel vision_head_dim = vision_hidden_size // mgvision.config.num_attention_heads copied_numel = 0 safe_copy(hfvision.rotary_pos_emb.inv_freq, mgvision.rotary_pos_emb.inv_freq) - copied_numel += safe_copy(hfvision.patch_embed.proj.weight, mgvision.patch_embed.proj.weight) + copied_numel += safe_copy( + hfvision.patch_embed.proj.weight, mgvision.patch_embed.proj.weight + ) for hfblock, mgblock in zip(hfvision.blocks, mgvision.decoder.layers, strict=True): # norm1 --> linear_qkv.norm - copied_numel += safe_copy(hfblock.norm1.weight, mgblock.self_attention.linear_qkv.layer_norm_weight) + copied_numel += safe_copy( + hfblock.norm1.weight, mgblock.self_attention.linear_qkv.layer_norm_weight + ) # norm2 --> 
mlp.linear_fc1.norm - copied_numel += safe_copy(hfblock.norm2.weight, mgblock.mlp.linear_fc1.layer_norm_weight) + copied_numel += safe_copy( + hfblock.norm2.weight, mgblock.mlp.linear_fc1.layer_norm_weight + ) # qkv --> self_attention.linear_qkv converted_weight = ( - hfblock.attn.qkv.weight.view(3, vision_num_query_groups, -1, vision_head_dim, vision_hidden_size) + hfblock.attn.qkv.weight.view( + 3, vision_num_query_groups, -1, vision_head_dim, vision_hidden_size + ) .transpose(0, 1) .flatten(1, 2) .reshape(-1, vision_hidden_size) .contiguous() ) - copied_numel += safe_copy(converted_weight, mgblock.self_attention.linear_qkv.weight) + copied_numel += safe_copy( + converted_weight, mgblock.self_attention.linear_qkv.weight + ) converted_bias = ( hfblock.attn.qkv.bias.view(3, vision_num_query_groups, -1) .transpose(0, 1) @@ -199,55 +274,105 @@ def convert_checkpoint_from_transformers_to_megatron_qwen2_5_vl(hfmodel, mgmodel .view(-1) .contiguous() ) - copied_numel += safe_copy(converted_bias, mgblock.self_attention.linear_qkv.bias) + copied_numel += safe_copy( + converted_bias, mgblock.self_attention.linear_qkv.bias + ) # proj --> self_attention.linear_proj - copied_numel += safe_copy(hfblock.attn.proj.weight, mgblock.self_attention.linear_proj.weight) - copied_numel += safe_copy(hfblock.attn.proj.bias, mgblock.self_attention.linear_proj.bias) + copied_numel += safe_copy( + hfblock.attn.proj.weight, mgblock.self_attention.linear_proj.weight + ) + copied_numel += safe_copy( + hfblock.attn.proj.bias, mgblock.self_attention.linear_proj.bias + ) # mlp --> mlp: gate - fc1_weight = torch.cat([hfblock.mlp.gate_proj.weight, hfblock.mlp.up_proj.weight]) + fc1_weight = torch.cat( + [hfblock.mlp.gate_proj.weight, hfblock.mlp.up_proj.weight] + ) fc1_bias = torch.cat([hfblock.mlp.gate_proj.bias, hfblock.mlp.up_proj.bias]) copied_numel += safe_copy(fc1_weight, mgblock.mlp.linear_fc1.weight) copied_numel += safe_copy(fc1_bias, mgblock.mlp.linear_fc1.bias) - copied_numel += 
safe_copy(hfblock.mlp.down_proj.weight, mgblock.mlp.linear_fc2.weight) - copied_numel += safe_copy(hfblock.mlp.down_proj.bias, mgblock.mlp.linear_fc2.bias) + copied_numel += safe_copy( + hfblock.mlp.down_proj.weight, mgblock.mlp.linear_fc2.weight + ) + copied_numel += safe_copy( + hfblock.mlp.down_proj.bias, mgblock.mlp.linear_fc2.bias + ) # 2. vision projector hfprojector = hfvision.merger mgprojector = mgvision.projection - copied_numel += safe_copy(hfprojector.ln_q.weight, mgvision.decoder.final_layernorm.weight) + copied_numel += safe_copy( + hfprojector.ln_q.weight, mgvision.decoder.final_layernorm.weight + ) - copied_numel += safe_copy(hfprojector.mlp[0].weight, mgprojector.encoder.linear_fc1.weight) - copied_numel += safe_copy(hfprojector.mlp[0].bias, mgprojector.encoder.linear_fc1.bias) - copied_numel += safe_copy(hfprojector.mlp[2].weight, mgprojector.encoder.linear_fc2.weight) - copied_numel += safe_copy(hfprojector.mlp[2].bias, mgprojector.encoder.linear_fc2.bias) + copied_numel += safe_copy( + hfprojector.mlp[0].weight, mgprojector.encoder.linear_fc1.weight + ) + copied_numel += safe_copy( + hfprojector.mlp[0].bias, mgprojector.encoder.linear_fc1.bias + ) + copied_numel += safe_copy( + hfprojector.mlp[2].weight, mgprojector.encoder.linear_fc2.weight + ) + copied_numel += safe_copy( + hfprojector.mlp[2].bias, mgprojector.encoder.linear_fc2.bias + ) n_params = sum([t.numel() for t in hfvision.state_dict().values()]) assert n_params == copied_numel # 3. 
llm [just Qwen2] hfllm = hfmodel.model mgllm = mgmodel.language_model copied_numel = 0 - copied_numel += safe_copy(hfllm.embed_tokens.weight, mgllm.embedding.word_embeddings.weight) + copied_numel += safe_copy( + hfllm.embed_tokens.weight, mgllm.embedding.word_embeddings.weight + ) for mglayer, hflayer in zip(mgllm.decoder.layers, hfllm.layers, strict=True): - copied_numel += safe_copy(hflayer.input_layernorm.weight, mglayer.self_attention.linear_qkv.layer_norm_weight) + copied_numel += safe_copy( + hflayer.input_layernorm.weight, + mglayer.self_attention.linear_qkv.layer_norm_weight, + ) - q_proj_weight = hflayer.self_attn.q_proj.weight.view(num_query_groups, -1, head_dim, hidden_size) - k_proj_weight = hflayer.self_attn.k_proj.weight.view(num_query_groups, -1, head_dim, hidden_size) - v_proj_weight = hflayer.self_attn.v_proj.weight.view(num_query_groups, -1, head_dim, hidden_size) - qkv_proj = torch.cat([q_proj_weight, k_proj_weight, v_proj_weight], dim=1).view(-1, hidden_size).contiguous() + q_proj_weight = hflayer.self_attn.q_proj.weight.view( + num_query_groups, -1, head_dim, hidden_size + ) + k_proj_weight = hflayer.self_attn.k_proj.weight.view( + num_query_groups, -1, head_dim, hidden_size + ) + v_proj_weight = hflayer.self_attn.v_proj.weight.view( + num_query_groups, -1, head_dim, hidden_size + ) + qkv_proj = ( + torch.cat([q_proj_weight, k_proj_weight, v_proj_weight], dim=1) + .view(-1, hidden_size) + .contiguous() + ) copied_numel += safe_copy(qkv_proj, mglayer.self_attention.linear_qkv.weight) q_proj_bias = hflayer.self_attn.q_proj.bias.view(num_query_groups, -1) k_proj_bias = hflayer.self_attn.k_proj.bias.view(num_query_groups, -1) v_proj_bias = hflayer.self_attn.v_proj.bias.view(num_query_groups, -1) - qkv_bias = torch.cat([q_proj_bias, k_proj_bias, v_proj_bias], dim=1).view(-1).contiguous() + qkv_bias = ( + torch.cat([q_proj_bias, k_proj_bias, v_proj_bias], dim=1) + .view(-1) + .contiguous() + ) copied_numel += safe_copy(qkv_bias, 
mglayer.self_attention.linear_qkv.bias) - copied_numel += safe_copy(hflayer.self_attn.o_proj.weight, mglayer.self_attention.linear_proj.weight) + copied_numel += safe_copy( + hflayer.self_attn.o_proj.weight, mglayer.self_attention.linear_proj.weight + ) - fc1_weight = torch.cat([hflayer.mlp.gate_proj.weight, hflayer.mlp.up_proj.weight]) + fc1_weight = torch.cat( + [hflayer.mlp.gate_proj.weight, hflayer.mlp.up_proj.weight] + ) copied_numel += safe_copy(fc1_weight, mglayer.mlp.linear_fc1.weight) - copied_numel += safe_copy(hflayer.mlp.down_proj.weight, mglayer.mlp.linear_fc2.weight) - copied_numel += safe_copy(hflayer.post_attention_layernorm.weight, mglayer.mlp.linear_fc1.layer_norm_weight) + copied_numel += safe_copy( + hflayer.mlp.down_proj.weight, mglayer.mlp.linear_fc2.weight + ) + copied_numel += safe_copy( + hflayer.post_attention_layernorm.weight, + mglayer.mlp.linear_fc1.layer_norm_weight, + ) copied_numel += safe_copy(hfllm.norm.weight, mgllm.decoder.final_layernorm.weight) if not hf_config.tie_word_embeddings: @@ -259,65 +384,118 @@ def convert_checkpoint_from_transformers_to_megatron_qwen2_5_vl(hfmodel, mgmodel @torch.no_grad() -def convert_checkpoint_from_transformers_to_megatron_dpskv3(hf_model, model, hf_config, tfconfig): +def convert_checkpoint_from_transformers_to_megatron_dpskv3( + hf_model, model, hf_config, tfconfig +): warnings.warn("MTP model is not supported yet", stacklevel=2) numel: int = 0 - numel += safe_copy(hf_model.model.embed_tokens.weight, model.embedding.word_embeddings.weight) + numel += safe_copy( + hf_model.model.embed_tokens.weight, model.embedding.word_embeddings.weight + ) print(f"{numel=}") - for layer_idx, (layer, hf_layer) in enumerate(zip(model.decoder.layers, hf_model.model.layers, strict=True)): + for layer_idx, (layer, hf_layer) in enumerate( + zip(model.decoder.layers, hf_model.model.layers, strict=True) + ): numel_cur: int = numel - numel += safe_copy(hf_layer.input_layernorm.weight, layer.input_layernorm.weight) + 
numel += safe_copy( + hf_layer.input_layernorm.weight, layer.input_layernorm.weight + ) if hf_config.q_lora_rank is None: - numel += safe_copy(hf_layer.self_attn.q_proj.weight, layer.self_attention.linear_q_proj.weight) + numel += safe_copy( + hf_layer.self_attn.q_proj.weight, + layer.self_attention.linear_q_proj.weight, + ) else: - numel += safe_copy(hf_layer.self_attn.q_a_proj.weight, layer.self_attention.linear_q_down_proj.weight) - numel += safe_copy(hf_layer.self_attn.q_b_proj.weight, layer.self_attention.linear_q_up_proj.weight) numel += safe_copy( - hf_layer.self_attn.q_a_layernorm.weight, layer.self_attention.linear_q_up_proj.layer_norm_weight + hf_layer.self_attn.q_a_proj.weight, + layer.self_attention.linear_q_down_proj.weight, + ) + numel += safe_copy( + hf_layer.self_attn.q_b_proj.weight, + layer.self_attention.linear_q_up_proj.weight, + ) + numel += safe_copy( + hf_layer.self_attn.q_a_layernorm.weight, + layer.self_attention.linear_q_up_proj.layer_norm_weight, ) numel += safe_copy( - hf_layer.self_attn.kv_a_proj_with_mqa.weight, layer.self_attention.linear_kv_down_proj.weight + hf_layer.self_attn.kv_a_proj_with_mqa.weight, + layer.self_attention.linear_kv_down_proj.weight, ) - numel += safe_copy(hf_layer.self_attn.kv_b_proj.weight, layer.self_attention.linear_kv_up_proj.weight) numel += safe_copy( - hf_layer.self_attn.kv_a_layernorm.weight, layer.self_attention.linear_kv_up_proj.layer_norm_weight + hf_layer.self_attn.kv_b_proj.weight, + layer.self_attention.linear_kv_up_proj.weight, + ) + numel += safe_copy( + hf_layer.self_attn.kv_a_layernorm.weight, + layer.self_attention.linear_kv_up_proj.layer_norm_weight, + ) + numel += safe_copy( + hf_layer.self_attn.o_proj.weight, layer.self_attention.linear_proj.weight ) - numel += safe_copy(hf_layer.self_attn.o_proj.weight, layer.self_attention.linear_proj.weight) if not hasattr(layer.mlp, "router"): - numel += safe_copy(hf_layer.post_attention_layernorm.weight, layer.mlp.linear_fc1.layer_norm_weight) numel += 
safe_copy( - torch.cat([hf_layer.mlp.gate_proj.weight, hf_layer.mlp.up_proj.weight]), layer.mlp.linear_fc1.weight + hf_layer.post_attention_layernorm.weight, + layer.mlp.linear_fc1.layer_norm_weight, + ) + numel += safe_copy( + torch.cat([hf_layer.mlp.gate_proj.weight, hf_layer.mlp.up_proj.weight]), + layer.mlp.linear_fc1.weight, + ) + numel += safe_copy( + hf_layer.mlp.down_proj.weight, layer.mlp.linear_fc2.weight ) - numel += safe_copy(hf_layer.mlp.down_proj.weight, layer.mlp.linear_fc2.weight) else: numel += safe_copy(hf_layer.mlp.gate.weight, layer.mlp.router.weight) # NOTE: the e_score_correction_bias in mcore model will be initialized with bfloat16 and \ # recover to fp32 in the first forward. There is always a diff in the bias between two models (~0.3%) numel += safe_copy( - hf_layer.mlp.gate.e_score_correction_bias, layer.mlp.router.expert_bias, skip_dtype_assert=True + hf_layer.mlp.gate.e_score_correction_bias, + layer.mlp.router.expert_bias, + skip_dtype_assert=True, ) if tfconfig.moe_grouped_gemm: for i, hf_expert in enumerate(hf_layer.mlp.experts): - fc1_weight = torch.cat([hf_expert.gate_proj.weight, hf_expert.up_proj.weight]) - linear_fc1_weighti = getattr(layer.mlp.experts.linear_fc1, "weight" + str(i)) + fc1_weight = torch.cat( + [hf_expert.gate_proj.weight, hf_expert.up_proj.weight] + ) + linear_fc1_weighti = getattr( + layer.mlp.experts.linear_fc1, "weight" + str(i) + ) numel += safe_copy(fc1_weight, linear_fc1_weighti) - linear_fc2_weighti = getattr(layer.mlp.experts.linear_fc2, "weight" + str(i)) + linear_fc2_weighti = getattr( + layer.mlp.experts.linear_fc2, "weight" + str(i) + ) numel += safe_copy(hf_expert.down_proj.weight, linear_fc2_weighti) else: for i, hf_expert in enumerate(hf_layer.mlp.experts): expert = layer.mlp.experts.local_experts[i] - fc1_weight = torch.cat([hf_expert.gate_proj.weight, hf_expert.up_proj.weight]) + fc1_weight = torch.cat( + [hf_expert.gate_proj.weight, hf_expert.up_proj.weight] + ) numel += safe_copy(fc1_weight, 
expert.linear_fc1.weight) - numel += safe_copy(hf_expert.down_proj.weight, expert.linear_fc2.weight) - numel += safe_copy(hf_layer.post_attention_layernorm.weight, layer.pre_mlp_layernorm.weight) + numel += safe_copy( + hf_expert.down_proj.weight, expert.linear_fc2.weight + ) + numel += safe_copy( + hf_layer.post_attention_layernorm.weight, layer.pre_mlp_layernorm.weight + ) shared_fc1_weight = torch.cat( - [hf_layer.mlp.shared_experts.gate_proj.weight, hf_layer.mlp.shared_experts.up_proj.weight] + [ + hf_layer.mlp.shared_experts.gate_proj.weight, + hf_layer.mlp.shared_experts.up_proj.weight, + ] + ) + numel += safe_copy( + shared_fc1_weight, layer.mlp.shared_experts.linear_fc1.weight + ) + numel += safe_copy( + hf_layer.mlp.shared_experts.down_proj.weight, + layer.mlp.shared_experts.linear_fc2.weight, ) - numel += safe_copy(shared_fc1_weight, layer.mlp.shared_experts.linear_fc1.weight) - numel += safe_copy(hf_layer.mlp.shared_experts.down_proj.weight, layer.mlp.shared_experts.linear_fc2.weight) print(f"{layer_idx=} {numel=} numel this layer={numel - numel_cur}") numel += safe_copy(hf_model.model.norm.weight, model.decoder.final_layernorm.weight) @@ -333,7 +511,13 @@ def noop_context() -> Any: yield -def convert_hf_to_mcore(hf_model_path, output_path, use_cpu_initialization=False, test=False, trust_remote_code=False): +def convert_hf_to_mcore( + hf_model_path, + output_path, + use_cpu_initialization=False, + test=False, + trust_remote_code=False, +): os.makedirs(output_path, exist_ok=True) if len(os.listdir(output_path)) > 0 and not test: print(f"Output path {output_path} is not empty, skipping conversion") @@ -375,7 +559,9 @@ def megatron_model_provider(pre_process, post_process): ) return parallel_model - context: Callable[..., ContextManager] = init_empty_weights if use_cpu_initialization else noop_context + context: Callable[..., ContextManager] = ( + init_empty_weights if use_cpu_initialization else noop_context + ) with context(): model = get_model( 
model_provider_func=megatron_model_provider, @@ -395,29 +581,44 @@ def megatron_model_provider(pre_process, post_process): # init hf model if "Qwen2_5_VLForConditionalGeneration" in hf_config.architectures: hf_model = AutoModelForImageTextToText.from_pretrained( - hf_model_path, torch_dtype=torch.bfloat16, trust_remote_code=trust_remote_code + hf_model_path, + torch_dtype=torch.bfloat16, + trust_remote_code=trust_remote_code, ) else: hf_model = AutoModelForCausalLM.from_pretrained( - hf_model_path, torch_dtype=torch.bfloat16, trust_remote_code=trust_remote_code + hf_model_path, + torch_dtype=torch.bfloat16, + trust_remote_code=trust_remote_code, ) hf_state_dict = hf_model.state_dict() # load hf state dict to megatron model if "Qwen2MoeForCausalLM" in hf_config.architectures: - convert_checkpoint_from_transformers_to_megatron(hf_model, model[0].module, hf_config) + convert_checkpoint_from_transformers_to_megatron( + hf_model, model[0].module, hf_config + ) elif "Qwen2_5_VLForConditionalGeneration" in hf_config.architectures: - convert_checkpoint_from_transformers_to_megatron_qwen2_5_vl(hf_model, model[0].module, hf_config) + convert_checkpoint_from_transformers_to_megatron_qwen2_5_vl( + hf_model, model[0].module, hf_config + ) elif "DeepseekV3ForCausalLM" in hf_config.architectures: numel: int = convert_checkpoint_from_transformers_to_megatron_dpskv3( hf_model, model[0].module, hf_config, tfconfig=tfconfig ) if numel != hf_model.num_parameters(): - warnings.warn(f"numel mismatch: {numel=} != {hf_model.num_parameters()=}", stacklevel=1) + warnings.warn( + f"numel mismatch: {numel=} != {hf_model.num_parameters()=}", + stacklevel=1, + ) elif "Qwen3MoeForCausalLM" in hf_config.architectures: - convert_checkpoint_from_transformers_to_megatron(hf_model, model[0].module, hf_config) + convert_checkpoint_from_transformers_to_megatron( + hf_model, model[0].module, hf_config + ) else: - assert not use_cpu_initialization, "use_cpu_initialization is only supported for MoE model" 
+ assert ( + not use_cpu_initialization + ), "use_cpu_initialization is only supported for MoE model" from verl.models.mcore.loader import load_state_dict_to_megatron_gptmodel load_state_dict_to_megatron_gptmodel( @@ -433,7 +634,12 @@ def megatron_model_provider(pre_process, post_process): # save megatron model if len(os.listdir(output_path)) == 0: - dist_checkpointing.save(megatron_state_dict, output_path, sharded_strategy=None, async_sharded_save=False) + dist_checkpointing.save( + megatron_state_dict, + output_path, + sharded_strategy=None, + async_sharded_save=False, + ) if test: test_conversion(megatron_model_provider, tfconfig, output_path, model) @@ -441,5 +647,9 @@ def megatron_model_provider(pre_process, post_process): if __name__ == "__main__": args = _init_args() convert_hf_to_mcore( - args.hf_model_path, args.output_path, args.use_cpu_initialization, args.test, args.trust_remote_code + args.hf_model_path, + args.output_path, + args.use_cpu_initialization, + args.test, + args.trust_remote_code, ) diff --git a/Agent0/executor_train/verl/scripts/diagnose.py b/Agent0/executor_train/verl/scripts/diagnose.py index 174b1f9..8a64e3d 100644 --- a/Agent0/executor_train/verl/scripts/diagnose.py +++ b/Agent0/executor_train/verl/scripts/diagnose.py @@ -61,10 +61,18 @@ def test_connection(name, url, timeout=10): try: _ = urlopen(url, timeout=timeout) except Exception as e: - print("Error open {}: {}, {}, DNS finished in {} sec.".format(name, url, e, dns_elapsed)) + print( + "Error open {}: {}, {}, DNS finished in {} sec.".format( + name, url, e, dns_elapsed + ) + ) return load_elapsed = time.time() - start - print("Timing for {}: {}, DNS: {:.4f} sec, LOAD: {:.4f} sec.".format(name, url, dns_elapsed, load_elapsed)) + print( + "Timing for {}: {}, DNS: {:.4f} sec, LOAD: {:.4f} sec.".format( + name, url, dns_elapsed, load_elapsed + ) + ) def check_python(): @@ -88,7 +96,9 @@ def check_pip(): def _get_current_git_commit(): try: - result = subprocess.run(["git", 
"rev-parse", "HEAD"], capture_output=True, text=True, check=True) + result = subprocess.run( + ["git", "rev-parse", "HEAD"], capture_output=True, text=True, check=True + ) return result.stdout.strip() except subprocess.CalledProcessError as e: print(f"Error running git command: {e.stderr.strip()}") @@ -162,7 +172,12 @@ def check_network(args): else: import warnings - warnings.warn("Region {} do not need specific test, please refer to global sites.".format(r), stacklevel=2) + warnings.warn( + "Region {} do not need specific test, please refer to global sites.".format( + r + ), + stacklevel=2, + ) for name, url in URLS.items(): test_connection(name, url, args.timeout) @@ -170,7 +185,13 @@ def check_network(args): def check_environment(): print("----------Environment----------") for k, v in os.environ.items(): - if k.startswith("VERL_") or k.startswith("OMP_") or k.startswith("KMP_") or k == "CC" or k == "CXX": + if ( + k.startswith("VERL_") + or k.startswith("OMP_") + or k.startswith("KMP_") + or k == "CC" + or k == "CXX" + ): print('{}="{}"'.format(k, v)) @@ -192,7 +213,9 @@ def check_cuda_versions(): import subprocess nvcc_output = subprocess.check_output(["nvcc", "--version"]).decode("utf-8") - cuda_compiler_version = next((line for line in nvcc_output.splitlines() if "release" in line), None) + cuda_compiler_version = next( + (line for line in nvcc_output.splitlines() if "release" in line), None + ) if cuda_compiler_version: print(f"CUDA Compiler : {cuda_compiler_version.strip()}") else: @@ -219,7 +242,11 @@ def _get_gpu_info(): """ try: result = subprocess.run( - ["nvidia-smi", "--query-gpu=gpu_name,memory.total", "--format=csv,noheader,nounits"], + [ + "nvidia-smi", + "--query-gpu=gpu_name,memory.total", + "--format=csv,noheader,nounits", + ], capture_output=True, text=True, check=True, @@ -268,7 +295,9 @@ def parse_args(): ) choices = ["python", "pip", "verl", "system", "os", "environment"] for choice in choices: - parser.add_argument("--" + choice, default=1, 
type=int, help="Diagnose {}.".format(choice)) + parser.add_argument( + "--" + choice, default=1, type=int, help="Diagnose {}.".format(choice) + ) parser.add_argument("--network", default=0, type=int, help="Diagnose network.") parser.add_argument("--hardware", default=0, type=int, help="Diagnose hardware.") parser.add_argument( @@ -278,7 +307,12 @@ def parse_args(): help="Additional sites in which region(s) to test. \ Specify 'cn' for example to test mirror sites in China.", ) - parser.add_argument("--timeout", default=10, type=int, help="Connection test timeout threshold, 0 to disable.") + parser.add_argument( + "--timeout", + default=10, + type=int, + help="Connection test timeout threshold, 0 to disable.", + ) args = parser.parse_args() return args diff --git a/Agent0/executor_train/verl/scripts/init_random_model.py b/Agent0/executor_train/verl/scripts/init_random_model.py index 2804bc2..cc9f068 100644 --- a/Agent0/executor_train/verl/scripts/init_random_model.py +++ b/Agent0/executor_train/verl/scripts/init_random_model.py @@ -31,21 +31,44 @@ import warnings from typing import Any -from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, PretrainedConfig +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoTokenizer, + PretrainedConfig, +) def _init_args(): parser = argparse.ArgumentParser() - parser.add_argument("--hf_model_path", type=str, required=True, help="The path for the huggingface model") - parser.add_argument("--new_config_path", type=str, required=True, help="The path for the new config file") - parser.add_argument("--output_path", type=str, required=True, help="The path for the output random model") + parser.add_argument( + "--hf_model_path", + type=str, + required=True, + help="The path for the huggingface model", + ) + parser.add_argument( + "--new_config_path", + type=str, + required=True, + help="The path for the new config file", + ) + parser.add_argument( + "--output_path", + type=str, + required=True, + 
help="The path for the output random model", + ) args = parser.parse_args() return args def check_output_path(output_path: str): if os.path.exists(output_path): - warnings.warn(f"Output path '{output_path}' already exists. Will do nothing.", stacklevel=2) + warnings.warn( + f"Output path '{output_path}' already exists. Will do nothing.", + stacklevel=2, + ) exit() else: os.makedirs(output_path, exist_ok=True) @@ -58,14 +81,15 @@ def check_configs(original_config: dict[str, Any], new_config: dict[str, Any]) - This is a placeholder function; actual implementation may vary based on requirements. """ # Example check: ensure 'model_type' is the same - if new_config.get("model_type", None) is not None and original_config.get("model_type") != new_config.get( + if new_config.get("model_type", None) is not None and original_config.get( "model_type" - ): + ) != new_config.get("model_type"): raise RuntimeError("Model types do not match.") for key in new_config: if key not in original_config: warnings.warn( - f"Key '{key}' in new config does not exist in original config, may not take effect.", stacklevel=2 + f"Key '{key}' in new config does not exist in original config, may not take effect.", + stacklevel=2, ) @@ -91,5 +115,7 @@ def init_random_model(hf_model_path, new_config_path, output_path): args = _init_args() check_output_path(args.output_path) init_random_model( - hf_model_path=args.hf_model_path, new_config_path=args.new_config_path, output_path=args.output_path + hf_model_path=args.hf_model_path, + new_config_path=args.new_config_path, + output_path=args.output_path, ) diff --git a/Agent0/executor_train/verl/scripts/legacy_model_merger.py b/Agent0/executor_train/verl/scripts/legacy_model_merger.py index 8a5224a..26c2684 100644 --- a/Agent0/executor_train/verl/scripts/legacy_model_merger.py +++ b/Agent0/executor_train/verl/scripts/legacy_model_merger.py @@ -115,7 +115,9 @@ def get_transformers_auto_model_class(self): elif "ForConditionalGeneration" in 
self.model_config.architectures[0]: return AutoModelForVision2Seq - raise NotImplementedError(f"Unknown architecture {self.model_config.architectures}") + raise NotImplementedError( + f"Unknown architecture {self.model_config.architectures}" + ) def patch_model_generation_config(self, model): """ @@ -126,7 +128,9 @@ def patch_model_generation_config(self, model): """ if model.can_generate(): try: - model.generation_config = GenerationConfig.from_pretrained(self.hf_model_config_path) + model.generation_config = GenerationConfig.from_pretrained( + self.hf_model_config_path + ) except OSError: print( f"Warning: Generation config file not found in {self.hf_model_config_path}, using a generation config created from the model config." @@ -170,13 +174,19 @@ def save_lora_adapter(self, state_dict: dict[str, torch.Tensor]): "target_modules": list(target_modules), } peft_config = peft.LoraConfig(**peft_dict).to_dict() - peft_config["task_type"] = peft_config["task_type"].value if peft_config["task_type"] else None - peft_config["peft_type"] = peft_config["peft_type"].value if peft_config["peft_type"] else None + peft_config["task_type"] = ( + peft_config["task_type"].value if peft_config["task_type"] else None + ) + peft_config["peft_type"] = ( + peft_config["peft_type"].value if peft_config["peft_type"] else None + ) peft_config["target_modules"] = list(peft_config["target_modules"]) lora_path = os.path.join(self.config.target_dir, "lora_adapter") os.makedirs(lora_path, exist_ok=True) - with open(os.path.join(lora_path, "adapter_config.json"), "w", encoding="utf-8") as f: + with open( + os.path.join(lora_path, "adapter_config.json"), "w", encoding="utf-8" + ) as f: json.dump(peft_config, f, ensure_ascii=False, indent=4) save_file(lora_params, os.path.join(lora_path, "adapter_model.safetensors")) @@ -193,7 +203,9 @@ def save_lora_adapter(self, state_dict: dict[str, torch.Tensor]): def save_hf_model_and_tokenizer(self, state_dict: dict[str, torch.Tensor]): auto_model_class = 
self.get_transformers_auto_model_class() with init_empty_weights(): - model = auto_model_class.from_config(self.model_config, torch_dtype=torch.bfloat16) + model = auto_model_class.from_config( + self.model_config, torch_dtype=torch.bfloat16 + ) model.to_empty(device="cpu") model = self.patch_model_generation_config(model) @@ -219,8 +231,16 @@ def upload_to_huggingface(self): from huggingface_hub import HfApi api = HfApi() - api.create_repo(repo_id=self.config.hf_upload_path, private=self.config.private, exist_ok=True) - api.upload_folder(folder_path=self.config.target_dir, repo_id=self.config.hf_upload_path, repo_type="model") + api.create_repo( + repo_id=self.config.hf_upload_path, + private=self.config.private, + exist_ok=True, + ) + api.upload_folder( + folder_path=self.config.target_dir, + repo_id=self.config.hf_upload_path, + repo_type="model", + ) @abstractmethod def merge_and_save(self): @@ -245,7 +265,9 @@ def _load_rank_zero_state_dict(self, world_size: int) -> dict: weights_only=False, ) - def _extract_device_mesh_info(self, state_dict: dict, world_size: int) -> tuple[np.ndarray, tuple[str, ...]]: + def _extract_device_mesh_info( + self, state_dict: dict, world_size: int + ) -> tuple[np.ndarray, tuple[str, ...]]: """ Retrieves sharding information (device_mesh, mesh_dim_names) from a DTensor in the state_dict. If no DTensor is found, infers a simple FSDP mesh based on world_size. @@ -269,7 +291,10 @@ def _calculate_shard_configuration( self, mesh: np.ndarray, mesh_dim_names: tuple[str, ...] 
) -> tuple[int, tuple[int, ...]]: """Calculates the total number of shards and the shape of the device mesh.""" - assert mesh_dim_names in (("fsdp",), ("ddp", "fsdp")), f"Unsupported mesh_dim_names {mesh_dim_names}" + assert mesh_dim_names in ( + ("fsdp",), + ("ddp", "fsdp"), + ), f"Unsupported mesh_dim_names {mesh_dim_names}" if "tp" in mesh_dim_names: # TODO: "tp" is not supported yet due to the above assert @@ -281,7 +306,9 @@ def _calculate_shard_configuration( return total_shards, mesh_shape - def _merge_by_placement(self, tensors: list[torch.Tensor], placement: Placement) -> torch.Tensor: + def _merge_by_placement( + self, tensors: list[torch.Tensor], placement: Placement + ) -> torch.Tensor: """Merges a list of tensors based on their DTensor placement""" if placement.is_replicate(): return tensors[0] @@ -293,19 +320,31 @@ def _merge_by_placement(self, tensors: list[torch.Tensor], placement: Placement) raise NotImplementedError(f"Unsupported placement: {placement}") def _load_and_merge_state_dicts( - self, world_size: int, total_shards: int, mesh_shape: tuple[int, ...], mesh_dim_names: tuple[str, ...] 
+ self, + world_size: int, + total_shards: int, + mesh_shape: tuple[int, ...], + mesh_dim_names: tuple[str, ...], ) -> dict[str, torch.Tensor]: model_state_dict_lst = [None] * total_shards def process_one_shard(rank: int, model_state_dict_lst: list): - model_path = Path(self.config.local_dir) / f"model_world_size_{world_size}_rank_{rank}.pt" + model_path = ( + Path(self.config.local_dir) + / f"model_world_size_{world_size}_rank_{rank}.pt" + ) state_dict = torch.load(model_path, map_location="cpu", weights_only=False) model_state_dict_lst[rank] = state_dict return state_dict with ThreadPoolExecutor(max_workers=min(32, os.cpu_count())) as executor: - futures = [executor.submit(process_one_shard, rank, model_state_dict_lst) for rank in range(total_shards)] - for future in tqdm(futures, desc=f"Loading {total_shards} FSDP shards", total=total_shards): + futures = [ + executor.submit(process_one_shard, rank, model_state_dict_lst) + for rank in range(total_shards) + ] + for future in tqdm( + futures, desc=f"Loading {total_shards} FSDP shards", total=total_shards + ): future.result() # Merge state dicts from all shards @@ -359,13 +398,19 @@ def merge_and_save(self): world_size = self._get_world_size() rank_zero_state_dict = self._load_rank_zero_state_dict(world_size) - mesh, mesh_dim_names = self._extract_device_mesh_info(rank_zero_state_dict, world_size) + mesh, mesh_dim_names = self._extract_device_mesh_info( + rank_zero_state_dict, world_size + ) print(f"Got device mesh {mesh}, mesh_dim_names {mesh_dim_names}") - total_shards, mesh_shape = self._calculate_shard_configuration(mesh, mesh_dim_names) + total_shards, mesh_shape = self._calculate_shard_configuration( + mesh, mesh_dim_names + ) print(f"Processing model shards with {total_shards} {mesh_shape} in total") - merged_state_dict = self._load_and_merge_state_dicts(world_size, total_shards, mesh_shape, mesh_dim_names) + merged_state_dict = self._load_and_merge_state_dicts( + world_size, total_shards, mesh_shape, 
mesh_dim_names + ) if self.config.operation == "test": if not self.config.test_hf_dir: @@ -381,7 +426,9 @@ def merge_and_save(self): def _test_state_dict(self, state_dict: dict[str, torch.Tensor]): auto_model_class = self.get_transformers_auto_model_class() - hf_model = auto_model_class.from_pretrained(self.config.test_hf_dir, torch_dtype=torch.bfloat16) + hf_model = auto_model_class.from_pretrained( + self.config.test_hf_dir, torch_dtype=torch.bfloat16 + ) hf_state_dict = hf_model.state_dict() del hf_model @@ -389,34 +436,46 @@ def _test_state_dict(self, state_dict: dict[str, torch.Tensor]): collected_keys = set(state_dict.keys()) missing_keys = hf_model_keys - collected_keys - assert len(missing_keys) == 0, f"Missing keys in collected state dict: {list(sorted(missing_keys))}" + assert ( + len(missing_keys) == 0 + ), f"Missing keys in collected state dict: {list(sorted(missing_keys))}" extra_keys = collected_keys - hf_model_keys - assert len(extra_keys) == 0, f"Extra keys in collected state dict: {list(sorted(extra_keys))}" + assert ( + len(extra_keys) == 0 + ), f"Extra keys in collected state dict: {list(sorted(extra_keys))}" for key in hf_model_keys: hf_shape = hf_state_dict[key].shape collected_shape = state_dict[key].shape - assert hf_shape == collected_shape, ( - f"Shape mismatch for key '{key}': original {hf_shape} vs collected {collected_shape}" - ) + assert ( + hf_shape == collected_shape + ), f"Shape mismatch for key '{key}': original {hf_shape} vs collected {collected_shape}" hf_dtype = hf_state_dict[key].dtype collected_dtype = state_dict[key].dtype - assert hf_dtype == collected_dtype, ( - f"Dtype mismatch for key '{key}': original {hf_dtype} vs collected {collected_dtype}" - ) + assert ( + hf_dtype == collected_dtype + ), f"Dtype mismatch for key '{key}': original {hf_dtype} vs collected {collected_dtype}" - torch.testing.assert_close(hf_state_dict[key], state_dict[key], atol=1e-6, rtol=1e-6) + torch.testing.assert_close( + hf_state_dict[key], 
state_dict[key], atol=1e-6, rtol=1e-6 + ) - print("FSDP checks passed: The merged state_dict matches the hf model saved by FSDPCheckpointManager.") + print( + "FSDP checks passed: The merged state_dict matches the hf model saved by FSDPCheckpointManager." + ) class MegatronModelMerger(BaseModelMerger): def __init__(self, config: ModelMergerConfig): - from verl.utils.megatron_utils import get_hf_config_and_tokenizer_checkpoint_path + from verl.utils.megatron_utils import ( + get_hf_config_and_tokenizer_checkpoint_path, + ) - config.hf_model_config_path = get_hf_config_and_tokenizer_checkpoint_path(config.local_dir) + config.hf_model_config_path = get_hf_config_and_tokenizer_checkpoint_path( + config.local_dir + ) super().__init__(config) self.params_mapping = { @@ -466,11 +525,15 @@ def _get_tp_pp_rank_from_sharded_dir(self, sharded_dir: str) -> tuple[int, int]: tp_rank = int(rank_list[0]) pp_rank = 0 - assert tp_rank is not None and pp_rank is not None, f"Invalid sharded dir {sharded_dir}" + assert ( + tp_rank is not None and pp_rank is not None + ), f"Invalid sharded dir {sharded_dir}" return tp_rank, pp_rank - def _check_megatron_checkpoint_path(self, model_path: str) -> tuple[list[str], int, int]: + def _check_megatron_checkpoint_path( + self, model_path: str + ) -> tuple[list[str], int, int]: """ Validates the Megatron checkpoint structure (presence of 'model.pt' in sharded directories). Determines TP and PP sizes from directory names. 
@@ -479,7 +542,9 @@ def _check_megatron_checkpoint_path(self, model_path: str) -> tuple[list[str], i pp_size = 0 sharded_dirs = sorted(os.listdir(model_path)) for sharded_dir in sharded_dirs: - assert "model.pt" in os.listdir(Path(model_path) / sharded_dir), f"model.pt not found in {sharded_dir}" + assert "model.pt" in os.listdir( + Path(model_path) / sharded_dir + ), f"model.pt not found in {sharded_dir}" tp_rank, pp_rank = self._get_tp_pp_rank_from_sharded_dir(sharded_dir) tp_size = max(tp_size, tp_rank + 1) pp_size = max(pp_size, pp_rank + 1) @@ -533,7 +598,12 @@ def _merge_across_tp( k = torch.cat(k_lst, dim=0) v = torch.cat(v_lst, dim=0) return [q, k, v] - elif "layer_norm" in key or "layernorm" in key or "router" in key or ("output_layer" in key and is_value_model): + elif ( + "layer_norm" in key + or "layernorm" in key + or "router" in key + or ("output_layer" in key and is_value_model) + ): return tp_data[0] else: dim = 0 @@ -548,13 +618,22 @@ def _load_state_dicts( def _process_one_megatron_shard(sharded_dir: str): model_file_path = Path(model_ckpt_path) / sharded_dir / "model.pt" - state_dict = torch.load(model_file_path, map_location="cpu", weights_only=False) + state_dict = torch.load( + model_file_path, map_location="cpu", weights_only=False + ) tp_rank, pp_rank = self._get_tp_pp_rank_from_sharded_dir(sharded_dir) model_state_dict_lst[pp_rank][tp_rank] = state_dict with ThreadPoolExecutor(max_workers=min(32, os.cpu_count())) as executor: - futures = [executor.submit(_process_one_megatron_shard, sharded_dir) for sharded_dir in sharded_dirs] - for future in tqdm(futures, desc=f"Loading {len(sharded_dirs)} Megatron shards", total=len(sharded_dirs)): + futures = [ + executor.submit(_process_one_megatron_shard, sharded_dir) + for sharded_dir in sharded_dirs + ] + for future in tqdm( + futures, + desc=f"Loading {len(sharded_dirs)} Megatron shards", + total=len(sharded_dirs), + ): future.result() return model_state_dict_lst @@ -598,12 +677,16 @@ def 
_merge_state_dicts( if "extra_state" in key: continue if self.config.tie_word_embedding and ("output_layer" in key): - print("skip lm_head and reward_head loading because of tie_word_embeddings") + print( + "skip lm_head and reward_head loading because of tie_word_embeddings" + ) continue self._check_megatron_state_key(key) hf_name = self._replace_name(key, self.params_mapping) - assert hf_name is not None, f"Failed to convert layer name [{key}] from megatron to huggingface." + assert ( + hf_name is not None + ), f"Failed to convert layer name [{key}] from megatron to huggingface." if "model.layers." in hf_name: local_layer_no = int(hf_name.split(".")[2]) layers_handled = max(local_layer_no, layers_handled) @@ -612,10 +695,22 @@ def _merge_state_dicts( new_key_list[2] = str(global_layer_no) hf_name = ".".join(new_key_list) else: - warnings.warn(f"hf_name {hf_name} will not be fixed with layer number", stacklevel=2) - - tp_data = [model_state_dict_lst[pp_rank][tp_rank][vpp_rank][key] for tp_rank in range(tp_size)] - merged = self._merge_across_tp(key, tp_data, self.model_config, tp_size, self.config.is_value_model) + warnings.warn( + f"hf_name {hf_name} will not be fixed with layer number", + stacklevel=2, + ) + + tp_data = [ + model_state_dict_lst[pp_rank][tp_rank][vpp_rank][key] + for tp_rank in range(tp_size) + ] + merged = self._merge_across_tp( + key, + tp_data, + self.model_config, + tp_size, + self.config.is_value_model, + ) if not isinstance(merged, list): state_dict[hf_name] = merged @@ -639,11 +734,19 @@ def merge_and_save(self): from verl.utils.megatron_utils import get_model_checkpoint_path model_ckpt_path = get_model_checkpoint_path(self.config.local_dir) - sharded_dirs, tp_size, pp_size = self._check_megatron_checkpoint_path(model_ckpt_path) - print(f"sharded_dirs: {sharded_dirs}, tp_size: {tp_size}, pp_size: {pp_size}, mp_size: {len(sharded_dirs)}") + sharded_dirs, tp_size, pp_size = self._check_megatron_checkpoint_path( + model_ckpt_path + ) + print( 
+ f"sharded_dirs: {sharded_dirs}, tp_size: {tp_size}, pp_size: {pp_size}, mp_size: {len(sharded_dirs)}" + ) - model_state_dict_lst = self._load_state_dicts(model_ckpt_path, sharded_dirs, tp_size, pp_size) - merged_state_dict = self._merge_state_dicts(model_state_dict_lst, tp_size, pp_size) + model_state_dict_lst = self._load_state_dicts( + model_ckpt_path, sharded_dirs, tp_size, pp_size + ) + merged_state_dict = self._merge_state_dicts( + model_state_dict_lst, tp_size, pp_size + ) del model_state_dict_lst if self.config.operation == "test": @@ -692,13 +795,24 @@ def _replace_name(self, megatron_name: str, name_mapping: dict[str, str]) -> str def main(): parser = argparse.ArgumentParser(description="verl model merger") - subparsers = parser.add_subparsers(dest="operation", required=True, help="Specify 'merge' or 'test' operation.") + subparsers = parser.add_subparsers( + dest="operation", required=True, help="Specify 'merge' or 'test' operation." + ) base_op_parser = argparse.ArgumentParser(add_help=False) base_op_parser.add_argument( - "--backend", type=str, required=True, choices=["fsdp", "megatron"], help="The backend of the model" + "--backend", + type=str, + required=True, + choices=["fsdp", "megatron"], + help="The backend of the model", + ) + base_op_parser.add_argument( + "--local_dir", + type=str, + required=True, + help="Path to the saved model checkpoints", ) - base_op_parser.add_argument("--local_dir", type=str, required=True, help="Path to the saved model checkpoints") base_op_parser.add_argument( "--hf_model_path", type=str, @@ -716,22 +830,37 @@ def main(): help="Whether the model is a value model (currently only Megatron supported)", ) - merge_parser = subparsers.add_parser("merge", parents=[base_op_parser], help="Merge model checkpoints and save.") + merge_parser = subparsers.add_parser( + "merge", parents=[base_op_parser], help="Merge model checkpoints and save." 
+ ) merge_parser.add_argument( - "--target_dir", default="tmp", type=str, help="Directory to save the merged huggingface model" + "--target_dir", + default="tmp", + type=str, + help="Directory to save the merged huggingface model", ) merge_parser.add_argument( - "--hf_upload_path", default=None, type=str, help="Hugging Face repository ID to upload the model" + "--hf_upload_path", + default=None, + type=str, + help="Hugging Face repository ID to upload the model", ) merge_parser.add_argument( - "--private", action="store_true", help="Whether to upload the model to a private Hugging Face repository" + "--private", + action="store_true", + help="Whether to upload the model to a private Hugging Face repository", ) test_parser = subparsers.add_parser( - "test", parents=[base_op_parser], help="Test merged model against a reference Hugging Face model" + "test", + parents=[base_op_parser], + help="Test merged model against a reference Hugging Face model", ) test_parser.add_argument( - "--test_hf_dir", type=str, required=True, help="Path to the reference Hugging Face model directory for testing" + "--test_hf_dir", + type=str, + required=True, + help="Path to the reference Hugging Face model directory for testing", ) args = parser.parse_args() diff --git a/Agent0/executor_train/verl/tests/experimental/agent_loop/agent_utils.py b/Agent0/executor_train/verl/tests/experimental/agent_loop/agent_utils.py index 3c708c4..1f9211b 100644 --- a/Agent0/executor_train/verl/tests/experimental/agent_loop/agent_utils.py +++ b/Agent0/executor_train/verl/tests/experimental/agent_loop/agent_utils.py @@ -25,7 +25,9 @@ def init_agent_loop_manager(config: DictConfig) -> AgentLoopManager | RayWorkerGroup: # =========================== 1. 
Create hybrid ActorRollout workers =========================== actor_rollout_cls = ( - AsyncActorRolloutRefWorker if config.actor_rollout_ref.rollout.mode == "async" else ActorRolloutRefWorker + AsyncActorRolloutRefWorker + if config.actor_rollout_ref.rollout.mode == "async" + else ActorRolloutRefWorker ) role_worker_mapping = { Role.ActorRollout: ray.remote(actor_rollout_cls), @@ -37,21 +39,29 @@ def init_agent_loop_manager(config: DictConfig) -> AgentLoopManager | RayWorkerG mapping = { Role.ActorRollout: global_pool_id, } - resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) + resource_pool_manager = ResourcePoolManager( + resource_pool_spec=resource_pool_spec, mapping=mapping + ) resource_pool_manager.create_resource_pool() - resource_pool_to_cls = {pool: {} for pool in resource_pool_manager.resource_pool_dict.values()} + resource_pool_to_cls = { + pool: {} for pool in resource_pool_manager.resource_pool_dict.values() + } # create actor and rollout resource_pool = resource_pool_manager.get_resource_pool(Role.ActorRollout) actor_rollout_cls = RayClassWithInitArgs( - cls=role_worker_mapping[Role.ActorRollout], config=config.actor_rollout_ref, role="actor_rollout" + cls=role_worker_mapping[Role.ActorRollout], + config=config.actor_rollout_ref, + role="actor_rollout", ) resource_pool_to_cls[resource_pool]["actor_rollout"] = actor_rollout_cls all_wg = {} for resource_pool, class_dict in resource_pool_to_cls.items(): worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict) - wg_dict = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls) + wg_dict = RayWorkerGroup( + resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls + ) spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys()) all_wg.update(spawn_wg) actor_rollout_wg = all_wg["actor_rollout"] diff --git a/Agent0/executor_train/verl/tests/experimental/agent_loop/test_basic_agent_loop.py 
b/Agent0/executor_train/verl/tests/experimental/agent_loop/test_basic_agent_loop.py index 20936aa..88a540d 100644 --- a/Agent0/executor_train/verl/tests/experimental/agent_loop/test_basic_agent_loop.py +++ b/Agent0/executor_train/verl/tests/experimental/agent_loop/test_basic_agent_loop.py @@ -71,7 +71,12 @@ def test_single_turn(init_config): "content": "Let's play a role playing game. Your name is Alice, your favorite color is blue.", } ], - [{"role": "user", "content": "Let's play a role playing game. Your name is Bob, your favorite color is red."}], + [ + { + "role": "user", + "content": "Let's play a role playing game. Your name is Bob, your favorite color is red.", + } + ], ] batch = DataProto( non_tensor_batch={ @@ -119,7 +124,9 @@ def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema: schema = get_json_schema(self.get_current_temperature) return OpenAIFunctionToolSchema(**schema) - async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]: + async def execute( + self, instance_id: str, parameters: dict[str, Any], **kwargs + ) -> tuple[str, float, dict]: try: result = self.get_current_temperature(**parameters) return json.dumps(result), 0, {} @@ -150,7 +157,9 @@ def get_temperature_date(self, location: str, date: str, unit: str = "celsius"): "unit": unit, } - async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]: + async def execute( + self, instance_id: str, parameters: dict[str, Any], **kwargs + ) -> tuple[str, float, dict]: try: result = self.get_temperature_date(**parameters) return json.dumps(result), 0, {} @@ -210,12 +219,17 @@ def test_tool_agent(init_config): "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant.\n\n" "Current Date: 2024-09-30", }, - {"role": "user", "content": "What's the temperature in San Francisco now? 
How about tomorrow?"}, + { + "role": "user", + "content": "What's the temperature in San Francisco now? How about tomorrow?", + }, ], ] batch = DataProto( non_tensor_batch={ - "raw_prompt": np.array([np.array(prompt) for prompt in raw_prompts], dtype=object), + "raw_prompt": np.array( + [np.array(prompt) for prompt in raw_prompts], dtype=object + ), "agent_name": np.array(["tool_agent"] * len(raw_prompts)), }, ) @@ -238,14 +252,20 @@ def test_tool_agent(init_config): tokenizer = hf_tokenizer(init_config.actor_rollout_ref.model.path) responses = result.batch["responses"] response_mask = result.batch["response_mask"] - assert responses.size() == response_mask.size(), f"{responses.size()} != {response_mask.size()}" + assert ( + responses.size() == response_mask.size() + ), f"{responses.size()} != {response_mask.size()}" # Decode responses with response_mask for i in range(len(responses)): valid_tokens = responses[i][response_mask[i].bool()] response_str = tokenizer.decode(valid_tokens) - assert "" not in response_str, f"found in response: {response_str}" - assert "" not in response_str, f"found in response: {response_str}" + assert ( + "" not in response_str + ), f"found in response: {response_str}" + assert ( + "" not in response_str + ), f"found in response: {response_str}" print(f"response: {response_str}") print("Test passed!") diff --git a/Agent0/executor_train/verl/tests/interactions/test_gsm8k_interaction.py b/Agent0/executor_train/verl/tests/interactions/test_gsm8k_interaction.py index bc16877..60b022f 100644 --- a/Agent0/executor_train/verl/tests/interactions/test_gsm8k_interaction.py +++ b/Agent0/executor_train/verl/tests/interactions/test_gsm8k_interaction.py @@ -41,12 +41,16 @@ async def test_start_interaction_with_instance_id(self): instance_id = "test_instance" ground_truth = "42" - result_id = await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth) + result_id = await self.interaction.start_interaction( + 
instance_id=instance_id, ground_truth=ground_truth + ) assert result_id == instance_id assert instance_id in self.interaction._instance_dict assert self.interaction._instance_dict[instance_id]["response"] == "" - assert self.interaction._instance_dict[instance_id]["ground_truth"] == ground_truth + assert ( + self.interaction._instance_dict[instance_id]["ground_truth"] == ground_truth + ) assert self.interaction._instance_dict[instance_id]["reward"] == 0.0 @pytest.mark.asyncio @@ -59,7 +63,9 @@ async def test_start_interaction_without_instance_id(self): assert result_id is not None assert len(result_id) == 36 # UUID4 length assert result_id in self.interaction._instance_dict - assert self.interaction._instance_dict[result_id]["ground_truth"] == ground_truth + assert ( + self.interaction._instance_dict[result_id]["ground_truth"] == ground_truth + ) @pytest.mark.asyncio async def test_start_interaction_without_ground_truth(self): @@ -78,13 +84,15 @@ async def test_generate_response_correct_answer_with_prefix(self): ground_truth = "42" # Setup instance - await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth) + await self.interaction.start_interaction( + instance_id=instance_id, ground_truth=ground_truth + ) messages = [{"role": "user", "content": "#### 42"}] with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=1.0): - should_terminate, response, reward, metadata = await self.interaction.generate_response( - instance_id, messages + should_terminate, response, reward, metadata = ( + await self.interaction.generate_response(instance_id, messages) ) assert should_terminate is True @@ -100,13 +108,15 @@ async def test_generate_response_correct_answer_without_prefix(self): ground_truth = "42" # Setup instance - await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth) + await self.interaction.start_interaction( + instance_id=instance_id, ground_truth=ground_truth + ) messages = 
[{"role": "user", "content": "42"}] with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=1.0): - should_terminate, response, reward, metadata = await self.interaction.generate_response( - instance_id, messages + should_terminate, response, reward, metadata = ( + await self.interaction.generate_response(instance_id, messages) ) assert should_terminate is True @@ -121,17 +131,22 @@ async def test_generate_response_incorrect_answer(self): ground_truth = "42" # Setup instance - await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth) + await self.interaction.start_interaction( + instance_id=instance_id, ground_truth=ground_truth + ) messages = [{"role": "user", "content": "24"}] with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=0.0): - should_terminate, response, reward, metadata = await self.interaction.generate_response( - instance_id, messages + should_terminate, response, reward, metadata = ( + await self.interaction.generate_response(instance_id, messages) ) assert should_terminate is False - assert response == "Your response is incorrect! You need to reflect on your answer and try again." + assert ( + response + == "Your response is incorrect! You need to reflect on your answer and try again." 
+ ) assert reward == 0.0 assert self.interaction._instance_dict[instance_id]["response"] == "#### 24" @@ -142,7 +157,9 @@ async def test_generate_response_multiple_messages(self): ground_truth = "42" # Setup instance - await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth) + await self.interaction.start_interaction( + instance_id=instance_id, ground_truth=ground_truth + ) messages = [ {"role": "user", "content": "What is 2+2?"}, @@ -151,8 +168,8 @@ async def test_generate_response_multiple_messages(self): ] with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=1.0): - should_terminate, response, reward, metadata = await self.interaction.generate_response( - instance_id, messages + should_terminate, response, reward, metadata = ( + await self.interaction.generate_response(instance_id, messages) ) assert should_terminate is True @@ -166,13 +183,15 @@ async def test_generate_response_no_user_message(self): ground_truth = "42" # Setup instance - await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth) + await self.interaction.start_interaction( + instance_id=instance_id, ground_truth=ground_truth + ) messages = [{"role": "assistant", "content": "Hello!"}] with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=0.0): - should_terminate, response, reward, metadata = await self.interaction.generate_response( - instance_id, messages + should_terminate, response, reward, metadata = ( + await self.interaction.generate_response(instance_id, messages) ) assert should_terminate is False @@ -185,16 +204,22 @@ async def test_calculate_score_direct_call(self): ground_truth = "42" # Setup instance - await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth) + await self.interaction.start_interaction( + instance_id=instance_id, ground_truth=ground_truth + ) # Set a response self.interaction._instance_dict[instance_id]["response"] = "#### 42" - 
with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=1.0) as mock_compute: + with patch( + "verl.utils.reward_score.gsm8k.compute_score", return_value=1.0 + ) as mock_compute: score = await self.interaction.calculate_score(instance_id) assert score == 1.0 - mock_compute.assert_called_once_with("#### 42", "42", method="flexible", format_score=0.0, score=1.0) + mock_compute.assert_called_once_with( + "#### 42", "42", method="flexible", format_score=0.0, score=1.0 + ) @pytest.mark.asyncio async def test_calculate_score_with_kwargs(self): @@ -203,16 +228,24 @@ async def test_calculate_score_with_kwargs(self): ground_truth = "42" # Setup instance - await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth) + await self.interaction.start_interaction( + instance_id=instance_id, ground_truth=ground_truth + ) # Set a response self.interaction._instance_dict[instance_id]["response"] = "#### 24" - with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=0.0) as mock_compute: - score = await self.interaction.calculate_score(instance_id, extra_param="test") + with patch( + "verl.utils.reward_score.gsm8k.compute_score", return_value=0.0 + ) as mock_compute: + score = await self.interaction.calculate_score( + instance_id, extra_param="test" + ) assert score == 0.0 - mock_compute.assert_called_once_with("#### 24", "42", method="flexible", format_score=0.0, score=1.0) + mock_compute.assert_called_once_with( + "#### 24", "42", method="flexible", format_score=0.0, score=1.0 + ) @pytest.mark.asyncio async def test_finalize_interaction(self): @@ -221,7 +254,9 @@ async def test_finalize_interaction(self): ground_truth = "42" # Setup instance - await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth) + await self.interaction.start_interaction( + instance_id=instance_id, ground_truth=ground_truth + ) assert instance_id in self.interaction._instance_dict @@ -236,7 +271,9 @@ async def 
test_finalize_interaction_with_kwargs(self): ground_truth = "42" # Setup instance - await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth) + await self.interaction.start_interaction( + instance_id=instance_id, ground_truth=ground_truth + ) assert instance_id in self.interaction._instance_dict @@ -259,14 +296,16 @@ async def test_full_interaction_workflow_correct(self): ground_truth = "42" # Start interaction - instance_id = await self.interaction.start_interaction(ground_truth=ground_truth) + instance_id = await self.interaction.start_interaction( + ground_truth=ground_truth + ) # Generate response with correct answer messages = [{"role": "user", "content": "42"}] with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=1.0): - should_terminate, response, reward, metadata = await self.interaction.generate_response( - instance_id, messages + should_terminate, response, reward, metadata = ( + await self.interaction.generate_response(instance_id, messages) ) assert should_terminate is True @@ -282,14 +321,16 @@ async def test_full_interaction_workflow_incorrect(self): ground_truth = "42" # Start interaction - instance_id = await self.interaction.start_interaction(ground_truth=ground_truth) + instance_id = await self.interaction.start_interaction( + ground_truth=ground_truth + ) # Generate response with incorrect answer messages = [{"role": "user", "content": "24"}] with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=0.0): - should_terminate, response, reward, metadata = await self.interaction.generate_response( - instance_id, messages + should_terminate, response, reward, metadata = ( + await self.interaction.generate_response(instance_id, messages) ) assert should_terminate is False @@ -300,8 +341,8 @@ async def test_full_interaction_workflow_incorrect(self): messages.append({"role": "user", "content": "42"}) with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=1.0): - 
should_terminate, response, reward, metadata = await self.interaction.generate_response( - instance_id, messages + should_terminate, response, reward, metadata = ( + await self.interaction.generate_response(instance_id, messages) ) assert should_terminate is True @@ -318,8 +359,12 @@ async def test_multiple_concurrent_interactions(self): ground_truth_2 = "24" # Start multiple interactions - instance_id_1 = await self.interaction.start_interaction(ground_truth=ground_truth_1) - instance_id_2 = await self.interaction.start_interaction(ground_truth=ground_truth_2) + instance_id_1 = await self.interaction.start_interaction( + ground_truth=ground_truth_1 + ) + instance_id_2 = await self.interaction.start_interaction( + ground_truth=ground_truth_2 + ) assert len(self.interaction._instance_dict) == 2 assert instance_id_1 in self.interaction._instance_dict @@ -329,9 +374,15 @@ async def test_multiple_concurrent_interactions(self): messages_1 = [{"role": "user", "content": "42"}] messages_2 = [{"role": "user", "content": "24"}] - with patch("verl.utils.reward_score.gsm8k.compute_score", side_effect=[1.0, 1.0]): - should_terminate_1, _, reward_1, _ = await self.interaction.generate_response(instance_id_1, messages_1) - should_terminate_2, _, reward_2, _ = await self.interaction.generate_response(instance_id_2, messages_2) + with patch( + "verl.utils.reward_score.gsm8k.compute_score", side_effect=[1.0, 1.0] + ): + should_terminate_1, _, reward_1, _ = ( + await self.interaction.generate_response(instance_id_1, messages_1) + ) + should_terminate_2, _, reward_2, _ = ( + await self.interaction.generate_response(instance_id_2, messages_2) + ) assert should_terminate_1 is True assert should_terminate_2 is True @@ -351,13 +402,15 @@ async def test_edge_case_empty_messages(self): ground_truth = "42" # Setup instance - await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth) + await self.interaction.start_interaction( + instance_id=instance_id, 
ground_truth=ground_truth + ) messages = [] with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=0.0): - should_terminate, response, reward, metadata = await self.interaction.generate_response( - instance_id, messages + should_terminate, response, reward, metadata = ( + await self.interaction.generate_response(instance_id, messages) ) assert should_terminate is False @@ -371,15 +424,15 @@ async def test_edge_case_message_without_content(self): ground_truth = "42" # Setup instance - await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth) + await self.interaction.start_interaction( + instance_id=instance_id, ground_truth=ground_truth + ) - messages = [ - {"role": "user"} # Missing content field - ] + messages = [{"role": "user"}] # Missing content field with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=0.0): - should_terminate, response, reward, metadata = await self.interaction.generate_response( - instance_id, messages + should_terminate, response, reward, metadata = ( + await self.interaction.generate_response(instance_id, messages) ) assert should_terminate is False @@ -414,7 +467,9 @@ def test_name_attribute_initialization(self): # Test with default name when not provided in config config_without_name = {} interaction_without_name = Gsm8kInteraction(config_without_name) - assert interaction_without_name.name == "interaction_agent" # Default from BaseInteraction + assert ( + interaction_without_name.name == "interaction_agent" + ) # Default from BaseInteraction # Test that name is accessible as attribute assert hasattr(self.interaction, "name") diff --git a/Agent0/executor_train/verl/tests/interactions/test_interaction_registry.py b/Agent0/executor_train/verl/tests/interactions/test_interaction_registry.py index 7fe193b..e70da36 100644 --- a/Agent0/executor_train/verl/tests/interactions/test_interaction_registry.py +++ 
b/Agent0/executor_train/verl/tests/interactions/test_interaction_registry.py @@ -35,7 +35,9 @@ def test_get_interaction_class(self): assert base_cls == BaseInteraction # Test getting gsm8k interaction class - gsm8k_cls = get_interaction_class("verl.interactions.gsm8k_interaction.Gsm8kInteraction") + gsm8k_cls = get_interaction_class( + "verl.interactions.gsm8k_interaction.Gsm8kInteraction" + ) assert gsm8k_cls == Gsm8kInteraction def test_initialize_single_interaction_from_config(self): @@ -104,14 +106,21 @@ def test_initialize_multiple_interactions_from_config(self): assert interaction_map["base_agent"].name == "base_agent" # Check custom config was passed - assert interaction_map["base_agent"].config.get("custom_param") == "test_value" + assert ( + interaction_map["base_agent"].config.get("custom_param") == "test_value" + ) finally: os.unlink(temp_config_path) def test_initialize_interaction_without_explicit_name(self): """Test that interaction name is derived from class name when not specified.""" config_content = { - "interaction": [{"class_name": "verl.interactions.gsm8k_interaction.Gsm8kInteraction", "config": {}}] + "interaction": [ + { + "class_name": "verl.interactions.gsm8k_interaction.Gsm8kInteraction", + "config": {}, + } + ] } with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: @@ -123,7 +132,9 @@ def test_initialize_interaction_without_explicit_name(self): # Check that interaction name was derived from class name assert len(interaction_map) == 1 - assert "gsm8k" in interaction_map # Should be "gsm8k" after removing "interaction" suffix + assert ( + "gsm8k" in interaction_map + ) # Should be "gsm8k" after removing "interaction" suffix assert isinstance(interaction_map["gsm8k"], Gsm8kInteraction) assert interaction_map["gsm8k"].name == "gsm8k" finally: @@ -146,7 +157,13 @@ def test_initialize_empty_config(self): def test_invalid_class_name(self): """Test handling of invalid class name.""" config_content = { - "interaction": 
[{"name": "invalid", "class_name": "invalid.module.InvalidClass", "config": {}}] + "interaction": [ + { + "name": "invalid", + "class_name": "invalid.module.InvalidClass", + "config": {}, + } + ] } with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: @@ -163,7 +180,11 @@ def test_duplicate_interaction_names(self): """Test handling of duplicate interaction names.""" config_content = { "interaction": [ - {"name": "duplicate", "class_name": "verl.interactions.base.BaseInteraction", "config": {}}, + { + "name": "duplicate", + "class_name": "verl.interactions.base.BaseInteraction", + "config": {}, + }, { "name": "duplicate", "class_name": "verl.interactions.gsm8k_interaction.Gsm8kInteraction", @@ -177,7 +198,9 @@ def test_duplicate_interaction_names(self): temp_config_path = f.name try: - with pytest.raises(ValueError, match="Duplicate interaction name 'duplicate' found"): + with pytest.raises( + ValueError, match="Duplicate interaction name 'duplicate' found" + ): initialize_interactions_from_config(temp_config_path) finally: os.unlink(temp_config_path) @@ -187,7 +210,10 @@ def test_auto_name_generation_edge_cases(self): config_content = { "interaction": [ {"class_name": "verl.interactions.base.BaseInteraction", "config": {}}, - {"class_name": "verl.interactions.gsm8k_interaction.Gsm8kInteraction", "config": {}}, + { + "class_name": "verl.interactions.gsm8k_interaction.Gsm8kInteraction", + "config": {}, + }, ] } diff --git a/Agent0/executor_train/verl/tests/models/test_transformer.py b/Agent0/executor_train/verl/tests/models/test_transformer.py index 111230a..2cecd83 100644 --- a/Agent0/executor_train/verl/tests/models/test_transformer.py +++ b/Agent0/executor_train/verl/tests/models/test_transformer.py @@ -45,10 +45,14 @@ def test_hf_casual_models(): # config = AutoConfig.from_pretrained(test_case) with torch.device("cuda"): model = AutoModelForCausalLM.from_config( - config=config, torch_dtype=torch.bfloat16, 
attn_implementation="flash_attention_2" + config=config, + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2", ) model = model.to(device="cuda") - input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device="cuda") + input_ids = torch.randint( + low=0, high=config.vocab_size, size=(batch_size, seqlen), device="cuda" + ) attention_mask = create_random_mask( input_ids=input_ids, max_ratio_of_left_padding=0.1, @@ -75,9 +79,14 @@ def test_hf_casual_models(): ).logits # (1, total_nnz, vocab_size) origin_logits = model( - input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + use_cache=False, ).logits - origin_logits_rmpad, origin_logits_indices, *_ = unpad_input(origin_logits, attention_mask) + origin_logits_rmpad, origin_logits_indices, *_ = unpad_input( + origin_logits, attention_mask + ) logits_rmpad = logits_rmpad.squeeze(0) log_probs = log_probs_from_logits_all_rmpad( @@ -117,10 +126,14 @@ def test_hf_value_models(): config.hidden_dropout = 0 with torch.device("cuda"): model = AutoModelForTokenClassification.from_config( - config=config, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" + config=config, + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2", ) model = model.to(device="cuda") - input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device="cuda") + input_ids = torch.randint( + low=0, high=config.vocab_size, size=(batch_size, seqlen), device="cuda" + ) attention_mask = create_random_mask( input_ids=input_ids, max_ratio_of_left_padding=0.1, @@ -142,7 +155,10 @@ def test_hf_value_models(): ).transpose(0, 1) origin_logits = model( - input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + 
use_cache=False, ).logits # input with input_ids_rmpad and postition_ids to enable flash attention varlen diff --git a/Agent0/executor_train/verl/tests/models/test_transformers_ulysses.py b/Agent0/executor_train/verl/tests/models/test_transformers_ulysses.py index 233633f..c7a0b65 100644 --- a/Agent0/executor_train/verl/tests/models/test_transformers_ulysses.py +++ b/Agent0/executor_train/verl/tests/models/test_transformers_ulysses.py @@ -20,7 +20,12 @@ import torch.distributed from flash_attn.bert_padding import index_first_axis, rearrange, unpad_input from torch.distributed import init_device_mesh -from transformers import AutoModelForCausalLM, LlamaConfig, PretrainedConfig, Qwen2Config +from transformers import ( + AutoModelForCausalLM, + LlamaConfig, + PretrainedConfig, + Qwen2Config, +) from verl.models.transformers.monkey_patch import apply_monkey_patch from verl.protocol import DataProto @@ -48,23 +53,45 @@ class SequenceParallelConfig: def test_configs(): return [ SequenceParallelConfig( - LlamaConfig(num_hidden_layers=2, num_attention_heads=32, num_key_value_heads=32), sp_size=8, is_valid=True + LlamaConfig( + num_hidden_layers=2, num_attention_heads=32, num_key_value_heads=32 + ), + sp_size=8, + is_valid=True, ), SequenceParallelConfig( - Qwen2Config(num_hidden_layers=2, num_attention_heads=28, num_key_value_heads=4, hidden_size=3584), + Qwen2Config( + num_hidden_layers=2, + num_attention_heads=28, + num_key_value_heads=4, + hidden_size=3584, + ), sp_size=4, is_valid=True, ), SequenceParallelConfig( - Qwen2Config(num_hidden_layers=2, num_attention_heads=28, num_key_value_heads=4, hidden_size=3584), + Qwen2Config( + num_hidden_layers=2, + num_attention_heads=28, + num_key_value_heads=4, + hidden_size=3584, + ), sp_size=8, is_valid=False, ), SequenceParallelConfig( - Qwen2Config(num_hidden_layers=2, num_attention_heads=32, num_key_value_heads=4), sp_size=4, is_valid=True + Qwen2Config( + num_hidden_layers=2, num_attention_heads=32, num_key_value_heads=4 + 
), + sp_size=4, + is_valid=True, ), SequenceParallelConfig( - Qwen2Config(num_hidden_layers=2, num_attention_heads=32, num_key_value_heads=4), sp_size=8, is_valid=True + Qwen2Config( + num_hidden_layers=2, num_attention_heads=32, num_key_value_heads=4 + ), + sp_size=8, + is_valid=True, ), ] @@ -80,10 +107,16 @@ def test_hf_casual_fwd_bwd(test_config): if not torch.distributed.is_initialized(): initialize_global_process_group() - context = contextlib.nullcontext() if test_config.is_valid else pytest.raises(AssertionError) + context = ( + contextlib.nullcontext() + if test_config.is_valid + else pytest.raises(AssertionError) + ) with context: world_size = torch.distributed.get_world_size() - _hf_casual_fwd_bwd(test_config.config, test_config.sp_size, world_size // test_config.sp_size) + _hf_casual_fwd_bwd( + test_config.config, test_config.sp_size, world_size // test_config.sp_size + ) # TODO: seems not work, will cause `socketStartConnect: Connect to xxx failed : Software caused connection abort` # torch.distributed.destroy_process_group() @@ -104,16 +137,23 @@ def _hf_casual_fwd(config, sp_size, dp_size): # patch before load with torch.device("cuda"): model = AutoModelForCausalLM.from_config( - config=config, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" + config=config, + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2", ) apply_monkey_patch(model, sp_size) model = model.to(device="cuda") sync_model_parameters_global(model) # different rank will generate different input_ids following fsdp - input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device="cuda") + input_ids = torch.randint( + low=0, high=config.vocab_size, size=(batch_size, seqlen), device="cuda" + ) attention_mask = create_random_mask( - input_ids=input_ids, max_ratio_of_left_padding=0, max_ratio_of_valid_token=0.9, min_ratio_of_valid_token=0.8 + input_ids=input_ids, + max_ratio_of_left_padding=0, + max_ratio_of_valid_token=0.9, 
+ min_ratio_of_valid_token=0.8, ) position_ids = compute_position_id_with_mask( attention_mask @@ -145,17 +185,25 @@ def _hf_casual_fwd(config, sp_size, dp_size): # slice input tensor for ulysses # input_ids are padded and sliced # postition_ids are only padded but not sliced - input_ids_rmpad_sliced, position_ids_rmpad_padded, pad_size = ulysses_pad_and_slice_inputs( - input_ids_rmpad, position_ids_rmpad, sp_size=get_ulysses_sequence_parallel_world_size() + input_ids_rmpad_sliced, position_ids_rmpad_padded, pad_size = ( + ulysses_pad_and_slice_inputs( + input_ids_rmpad, + position_ids_rmpad, + sp_size=get_ulysses_sequence_parallel_world_size(), + ) ) # input with input_ids_rmpad and postition_ids to enable flash attention varlen logits_split_in_seq = model( - input_ids_rmpad_sliced, position_ids=position_ids_rmpad_padded, use_cache=False + input_ids_rmpad_sliced, + position_ids=position_ids_rmpad_padded, + use_cache=False, ).logits # (1, total_nnz/n, vocab_size) # all_gather output - logits_full = gather_outpus_and_unpad(logits_split_in_seq, gather_dim=1, unpad_dim=1, padding_size=pad_size) + logits_full = gather_outpus_and_unpad( + logits_split_in_seq, gather_dim=1, unpad_dim=1, padding_size=pad_size + ) # 2. 
perform normal forward set_ulysses_sequence_parallel_group(None) @@ -183,16 +231,23 @@ def _hf_casual_fwd_bwd(config, sp_size, dp_size): # patch before load with torch.device("cuda"): model = AutoModelForCausalLM.from_config( - config=config, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" + config=config, + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2", ) apply_monkey_patch(model, sp_size) model = model.to(device="cuda") sync_model_parameters_global(model) # different rank will generate different input_ids following fsdp - input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device="cuda") + input_ids = torch.randint( + low=0, high=config.vocab_size, size=(batch_size, seqlen), device="cuda" + ) attention_mask = create_random_mask( - input_ids=input_ids, max_ratio_of_left_padding=0, max_ratio_of_valid_token=0.9, min_ratio_of_valid_token=0.8 + input_ids=input_ids, + max_ratio_of_left_padding=0, + max_ratio_of_valid_token=0.9, + min_ratio_of_valid_token=0.8, ) position_ids = compute_position_id_with_mask( attention_mask @@ -224,17 +279,25 @@ def _hf_casual_fwd_bwd(config, sp_size, dp_size): # slice input tensor for ulysses # input_ids are padded and sliced # postition_ids are only padded but not sliced - input_ids_rmpad_sliced, position_ids_rmpad_padded, pad_size = ulysses_pad_and_slice_inputs( - input_ids_rmpad, position_ids_rmpad, sp_size=get_ulysses_sequence_parallel_world_size() + input_ids_rmpad_sliced, position_ids_rmpad_padded, pad_size = ( + ulysses_pad_and_slice_inputs( + input_ids_rmpad, + position_ids_rmpad, + sp_size=get_ulysses_sequence_parallel_world_size(), + ) ) # input with input_ids_rmpad and postition_ids to enable flash attention varlen logits_split_in_seq = model( - input_ids_rmpad_sliced, position_ids=position_ids_rmpad_padded, use_cache=False + input_ids_rmpad_sliced, + position_ids=position_ids_rmpad_padded, + use_cache=False, ).logits # (1, total_nnz/n, vocab_size) # 
all_gather output - logits_full = gather_outpus_and_unpad(logits_split_in_seq, gather_dim=1, unpad_dim=1, padding_size=pad_size) + logits_full = gather_outpus_and_unpad( + logits_split_in_seq, gather_dim=1, unpad_dim=1, padding_size=pad_size + ) # 2. perform normal forward set_ulysses_sequence_parallel_group(None) diff --git a/Agent0/executor_train/verl/tests/single_controller/check_worker_alive/main.py b/Agent0/executor_train/verl/tests/single_controller/check_worker_alive/main.py index cbdee9a..67d65e5 100644 --- a/Agent0/executor_train/verl/tests/single_controller/check_worker_alive/main.py +++ b/Agent0/executor_train/verl/tests/single_controller/check_worker_alive/main.py @@ -20,7 +20,11 @@ from verl.single_controller.base.decorator import Dispatch, register from verl.single_controller.base.worker import Worker -from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup +from verl.single_controller.ray.base import ( + RayClassWithInitArgs, + RayResourcePool, + RayWorkerGroup, +) @ray.remote diff --git a/Agent0/executor_train/verl/tests/single_controller/detached_worker/client.py b/Agent0/executor_train/verl/tests/single_controller/detached_worker/client.py index 52f2c72..d80af70 100644 --- a/Agent0/executor_train/verl/tests/single_controller/detached_worker/client.py +++ b/Agent0/executor_train/verl/tests/single_controller/detached_worker/client.py @@ -42,13 +42,23 @@ def compute_position_id_with_mask(mask): sequence_length = 1024 # give Trainer some data to train - input_ids = torch.randint(low=0, high=256, size=(batch_size, sequence_length), dtype=torch.int64, device="cuda") + input_ids = torch.randint( + low=0, + high=256, + size=(batch_size, sequence_length), + dtype=torch.int64, + device="cuda", + ) attention_mask = torch.ones_like(input_ids) position_ids = compute_position_id_with_mask(attention_mask) data = DataProto( batch=TensorDict( - {"input_ids": input_ids, "attention_mask": attention_mask, "position_ids": 
position_ids}, + { + "input_ids": input_ids, + "attention_mask": attention_mask, + "position_ids": position_ids, + }, batch_size=batch_size, ), meta_info={}, diff --git a/Agent0/executor_train/verl/tests/single_controller/detached_worker/server.py b/Agent0/executor_train/verl/tests/single_controller/detached_worker/server.py index 57e555a..7745856 100644 --- a/Agent0/executor_train/verl/tests/single_controller/detached_worker/server.py +++ b/Agent0/executor_train/verl/tests/single_controller/detached_worker/server.py @@ -38,7 +38,11 @@ from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup from verl.utils.megatron.optimizer import get_megatron_optimizer -from verl.utils.megatron_utils import get_model, init_megatron_optim_config, mcore_model_parallel_config +from verl.utils.megatron_utils import ( + get_model, + init_megatron_optim_config, + mcore_model_parallel_config, +) @ray.remote @@ -75,7 +79,9 @@ def init_model(self): num_key_value_heads=16, ) - megatron_config = mcore_model_parallel_config(sequence_parallel=True, params_dtype=torch.bfloat16) + megatron_config = mcore_model_parallel_config( + sequence_parallel=True, params_dtype=torch.bfloat16 + ) self.megatron_config = megatron_config def megatron_actor_model_provider(pre_process, post_process): @@ -102,7 +108,9 @@ def megatron_actor_model_provider(pre_process, post_process): optim_config = init_megatron_optim_config(optim_config) self.optimizer_config = optim_config - actor_optimizer = get_megatron_optimizer(model=actor_module, config=optim_config) + actor_optimizer = get_megatron_optimizer( + model=actor_module, config=optim_config + ) self.model = actor_module[0] self.optimizer = actor_optimizer @@ -118,14 +126,20 @@ def train_model(self, data: DataProto) -> DataProto: zero_buffer=(not self.optimizer_config.use_distributed_optimizer) ) # use use_contiguous_buffers_in_local_ddp and no overlap_dp_param_comm # 
update for 1 iteration - output = self.model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids).logits + output = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + ).logits output.mean().backward() update_successful, grad_norm, num_zeros_in_grad = self.optimizer.step( self.megatron_config, self.megatron_config.timers ) - return DataProto(batch=TensorDict({"loss": output.detach()}, batch_size=output.shape[0])) + return DataProto( + batch=TensorDict({"loss": output.detach()}, batch_size=output.shape[0]) + ) if __name__ == "__main__": diff --git a/Agent0/executor_train/verl/tests/single_controller/test_auto_padding_on_cpu.py b/Agent0/executor_train/verl/tests/single_controller/test_auto_padding_on_cpu.py index f2c4412..fdfdbf0 100644 --- a/Agent0/executor_train/verl/tests/single_controller/test_auto_padding_on_cpu.py +++ b/Agent0/executor_train/verl/tests/single_controller/test_auto_padding_on_cpu.py @@ -20,7 +20,11 @@ from verl.protocol import DataProtoConfig from verl.single_controller.base import Worker from verl.single_controller.base.decorator import Dispatch, register -from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup +from verl.single_controller.ray.base import ( + RayClassWithInitArgs, + RayResourcePool, + RayWorkerGroup, +) # or set env var VERL_AUTO_PADDING = "1" / "true" DataProtoConfig.auto_padding = True @@ -47,16 +51,24 @@ def test_auto_padding(): # test locally first for test_size in range(4, 20): - local_data = DataProto.from_dict({"a": torch.zeros(test_size)}, {"na": np.zeros(test_size, dtype=object)}) + local_data = DataProto.from_dict( + {"a": torch.zeros(test_size)}, {"na": np.zeros(test_size, dtype=object)} + ) # print(f"before padding, local_data = {local_data}") - padding_size = (chunk_size - (test_size % chunk_size)) if (test_size % chunk_size > 0) else 0 + padding_size = ( + (chunk_size - (test_size % chunk_size)) + if 
(test_size % chunk_size > 0) + else 0 + ) local_data.padding(padding_size) # print(f"after padding, local_data = {local_data}") - assert len(local_data) == len(local_data) + len(local_data) % chunk_size, ( - f"expecting padded length to be {len(local_data) + len(local_data) % chunk_size}, but got {len(local_data)}" - ) + assert ( + len(local_data) == len(local_data) + len(local_data) % chunk_size + ), f"expecting padded length to be {len(local_data) + len(local_data) % chunk_size}, but got {len(local_data)}" chunked = local_data.chunk(chunk_size) - assert len(chunked) == chunk_size, f"during test_size = {test_size}, expecting {chunk_size}, got {chunked}" + assert ( + len(chunked) == chunk_size + ), f"during test_size = {test_size}, expecting {chunk_size}, got {chunked}" for dp in chunked: assert len(dp) == test_size // chunk_size + bool(test_size % chunk_size), ( f"test size = {test_size}, expecting dp to be length of " @@ -64,19 +76,28 @@ def test_auto_padding(): ) # test with RayWorkerGroup method decorated as dispatch_mode=Dispatch.DP_COMPUTE_PROTO - data = DataProto.from_dict({"a": torch.zeros(10)}, {"na": np.array([str(i) for i in range(10)], dtype=object)}) + data = DataProto.from_dict( + {"a": torch.zeros(10)}, + {"na": np.array([str(i) for i in range(10)], dtype=object)}, + ) output = actor_wg.add(data) print(output.batch["a"]) assert len(output) == 10 - data = DataProto.from_dict({"a": torch.zeros(1)}, {"na": np.array([str(i) for i in range(1)], dtype=object)}) + data = DataProto.from_dict( + {"a": torch.zeros(1)}, + {"na": np.array([str(i) for i in range(1)], dtype=object)}, + ) output = actor_wg.add(data) print(output.batch["a"]) assert len(output) == 1 - data = DataProto.from_dict({"a": torch.zeros(8)}, {"na": np.array([str(i) for i in range(8)], dtype=object)}) + data = DataProto.from_dict( + {"a": torch.zeros(8)}, + {"na": np.array([str(i) for i in range(8)], dtype=object)}, + ) output = actor_wg.add(data) print(output.batch["a"]) @@ -86,21 +107,26 @@ 
def test_auto_padding(): DataProtoConfig.auto_padding = False data = DataProto.from_dict( - {"a": torch.zeros(10)}, {"na": np.array([str(i) for i in range(10)], dtype=object)}, auto_padding=True + {"a": torch.zeros(10)}, + {"na": np.array([str(i) for i in range(10)], dtype=object)}, + auto_padding=True, ) output = actor_wg.add(data) print(output.batch["a"]) assert len(output) == 10 data = DataProto.from_single_dict( - {"a": torch.zeros(1), "na": np.array([str(i) for i in range(1)], dtype=object)}, auto_padding=True + {"a": torch.zeros(1), "na": np.array([str(i) for i in range(1)], dtype=object)}, + auto_padding=True, ) output = actor_wg.add(data) print(output.batch["a"]) assert len(output) == 1 - data = DataProto.from_single_dict({"a": torch.zeros(8), "na": np.array([str(i) for i in range(8)], dtype=object)}) + data = DataProto.from_single_dict( + {"a": torch.zeros(8), "na": np.array([str(i) for i in range(8)], dtype=object)} + ) output = actor_wg.add(data) print(output.batch["a"]) diff --git a/Agent0/executor_train/verl/tests/single_controller/test_colocated_workers.py b/Agent0/executor_train/verl/tests/single_controller/test_colocated_workers.py index cdaa747..809ff9a 100644 --- a/Agent0/executor_train/verl/tests/single_controller/test_colocated_workers.py +++ b/Agent0/executor_train/verl/tests/single_controller/test_colocated_workers.py @@ -60,7 +60,9 @@ def test_colocated_workers(): resource_pool = RayResourcePool(process_on_nodes=[2]) actor_wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=actor_cls) - critic_wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=critic_cls) + critic_wg = RayWorkerGroup( + resource_pool=resource_pool, ray_cls_with_init=critic_cls + ) expected_actor_output = actor_wg.add(data) expected_critic_output = critic_wg.sub(data) @@ -68,7 +70,9 @@ def test_colocated_workers(): # create colocated workers cls_dict = {"actor": actor_cls, "critic": critic_cls} ray_cls_with_init = 
create_colocated_worker_cls(cls_dict) - wg_dict = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init) + wg_dict = RayWorkerGroup( + resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init + ) spawn_wg = wg_dict.spawn(prefix_set=cls_dict.keys()) colocated_actor_wg = spawn_wg["actor"] @@ -77,7 +81,11 @@ def test_colocated_workers(): actor_output = colocated_actor_wg.add(data) critic_output = colocated_critic_wg.sub(data) - torch.testing.assert_close(expected_actor_output.batch, actor_output.batch, atol=0, rtol=0) - torch.testing.assert_close(expected_critic_output.batch, critic_output.batch, atol=0, rtol=0) + torch.testing.assert_close( + expected_actor_output.batch, actor_output.batch, atol=0, rtol=0 + ) + torch.testing.assert_close( + expected_critic_output.batch, critic_output.batch, atol=0, rtol=0 + ) ray.shutdown() diff --git a/Agent0/executor_train/verl/tests/single_controller/test_colocated_workers_fused.py b/Agent0/executor_train/verl/tests/single_controller/test_colocated_workers_fused.py index 93b1a72..b89586b 100644 --- a/Agent0/executor_train/verl/tests/single_controller/test_colocated_workers_fused.py +++ b/Agent0/executor_train/verl/tests/single_controller/test_colocated_workers_fused.py @@ -60,7 +60,9 @@ def test_colocated_workers_fused(): resource_pool = RayResourcePool(process_on_nodes=[2]) actor_wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=actor_cls) - critic_wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=critic_cls) + critic_wg = RayWorkerGroup( + resource_pool=resource_pool, ray_cls_with_init=critic_cls + ) expected_actor_output = actor_wg.add(data) expected_critic_output = critic_wg.sub(data) @@ -68,7 +70,9 @@ def test_colocated_workers_fused(): # create colocated workers cls_dict = {"actor": actor_cls, "critic": critic_cls} ray_cls_with_init = create_colocated_worker_cls_fused(cls_dict) - wg_dict = RayWorkerGroup(resource_pool=resource_pool, 
ray_cls_with_init=ray_cls_with_init) + wg_dict = RayWorkerGroup( + resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init + ) spawn_wg = wg_dict.spawn(prefix_set=cls_dict.keys()) colocated_actor_wg = spawn_wg["actor"] @@ -77,7 +81,11 @@ def test_colocated_workers_fused(): actor_output = colocated_actor_wg.add(data) critic_output = colocated_critic_wg.sub(data) - torch.testing.assert_close(expected_actor_output.batch, actor_output.batch, atol=0, rtol=0) - torch.testing.assert_close(expected_critic_output.batch, critic_output.batch, atol=0, rtol=0) + torch.testing.assert_close( + expected_actor_output.batch, actor_output.batch, atol=0, rtol=0 + ) + torch.testing.assert_close( + expected_critic_output.batch, critic_output.batch, atol=0, rtol=0 + ) ray.shutdown() diff --git a/Agent0/executor_train/verl/tests/single_controller/test_data_transfer.py b/Agent0/executor_train/verl/tests/single_controller/test_data_transfer.py index 13777b0..5095b03 100644 --- a/Agent0/executor_train/verl/tests/single_controller/test_data_transfer.py +++ b/Agent0/executor_train/verl/tests/single_controller/test_data_transfer.py @@ -24,7 +24,11 @@ from verl import DataProto from verl.single_controller.base import Worker from verl.single_controller.base.decorator import Dispatch, register -from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup +from verl.single_controller.ray import ( + RayClassWithInitArgs, + RayResourcePool, + RayWorkerGroup, +) from verl.utils.ray_utils import parallel_put @@ -98,7 +102,9 @@ def test_data_transfer(): for input_data, output_data in zip(data_list, output_lst, strict=True): for key in input_data.batch.keys(): - assert torch.all(torch.eq(input_data.batch[key] + 1, output_data.batch[key])), ( + assert torch.all( + torch.eq(input_data.batch[key] + 1, output_data.batch[key]) + ), ( input_data.batch[key], output_data.batch[key], key, diff --git 
a/Agent0/executor_train/verl/tests/single_controller/test_decorator_on_cpu.py b/Agent0/executor_train/verl/tests/single_controller/test_decorator_on_cpu.py index 4dfec63..e0d0511 100644 --- a/Agent0/executor_train/verl/tests/single_controller/test_decorator_on_cpu.py +++ b/Agent0/executor_train/verl/tests/single_controller/test_decorator_on_cpu.py @@ -23,7 +23,11 @@ from verl.protocol import DataProto, DataProtoFuture from verl.single_controller.base.decorator import Dispatch, register from verl.single_controller.base.worker import Worker -from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup +from verl.single_controller.ray import ( + RayClassWithInitArgs, + RayResourcePool, + RayWorkerGroup, +) # Pytest fixture for Ray setup/teardown @@ -47,7 +51,11 @@ def __init__(self, initial_value=0): @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) def dp_compute(self, data: DataProto) -> DataProto: time.sleep(0.1) # Simulate work - rank_value = torch.tensor(self.rank, device=data.batch["input"].device, dtype=data.batch["input"].dtype) + rank_value = torch.tensor( + self.rank, + device=data.batch["input"].device, + dtype=data.batch["input"].dtype, + ) data.batch["output"] = data.batch["input"] + self.value + rank_value return data @@ -56,7 +64,11 @@ def dp_compute(self, data: DataProto) -> DataProto: async def async_dp_compute(self, data: DataProto) -> DataProto: # Simulate async work await asyncio.sleep(0.1) # Simulate async work - rank_value = torch.tensor(self.rank, device=data.batch["input"].device, dtype=data.batch["input"].dtype) + rank_value = torch.tensor( + self.rank, + device=data.batch["input"].device, + dtype=data.batch["input"].dtype, + ) data.batch["output_async"] = data.batch["input"] * 2 + self.value + rank_value return data @@ -68,10 +80,14 @@ def test_decorator_dp_compute(ray_init_shutdown): Verifies the result correctness. 
""" num_workers = 2 - resource_pool = RayResourcePool([num_workers], use_gpu=False, max_colocate_count=1) # Use CPU for simplicity + resource_pool = RayResourcePool( + [num_workers], use_gpu=False, max_colocate_count=1 + ) # Use CPU for simplicity cls_with_args = RayClassWithInitArgs(cls=DecoratorTestWorker, initial_value=10) worker_group = RayWorkerGroup( - resource_pool, cls_with_args, name_prefix=f"decorator_test_sync_dp_{int(time.time())}" + resource_pool, + cls_with_args, + name_prefix=f"decorator_test_sync_dp_{int(time.time())}", ) # Prepare input data (size 4, for 2 workers) @@ -94,7 +110,11 @@ def test_decorator_dp_compute(ray_init_shutdown): expected_output_part2 = torch.tensor([2, 3], dtype=torch.float32) + 10 + 1 expected_output = torch.cat([expected_output_part1, expected_output_part2]) - torch.testing.assert_close(output.batch["output"], expected_output, msg="Sync DP compute output data mismatch") + torch.testing.assert_close( + output.batch["output"], + expected_output, + msg="Sync DP compute output data mismatch", + ) # Test function for async def method with DP compute @@ -107,7 +127,9 @@ def test_decorator_async_function(ray_init_shutdown): resource_pool = RayResourcePool([num_workers], use_gpu=False, max_colocate_count=1) cls_with_args = RayClassWithInitArgs(cls=DecoratorTestWorker, initial_value=5) worker_group = RayWorkerGroup( - resource_pool, cls_with_args, name_prefix=f"decorator_test_async_dp_{int(time.time())}" + resource_pool, + cls_with_args, + name_prefix=f"decorator_test_async_dp_{int(time.time())}", ) # Prepare input data (size 4, for 2 workers) @@ -118,7 +140,9 @@ def test_decorator_async_function(ray_init_shutdown): future_output: DataProtoFuture = worker_group.async_dp_compute(data) # Assert that the call returned a future - assert isinstance(future_output, DataProtoFuture), "Expected DataProtoFuture for async def call" + assert isinstance( + future_output, DataProtoFuture + ), "Expected DataProtoFuture for async def call" # Get the 
result (this should block) result_data = future_output.get() @@ -137,5 +161,7 @@ def test_decorator_async_function(ray_init_shutdown): expected_output = torch.cat([expected_output_part1, expected_output_part2]) torch.testing.assert_close( - result_data.batch["output_async"], expected_output, msg="Async DP compute output data mismatch" + result_data.batch["output_async"], + expected_output, + msg="Async DP compute output data mismatch", ) diff --git a/Agent0/executor_train/verl/tests/single_controller/test_driverfunc_to_worker.py b/Agent0/executor_train/verl/tests/single_controller/test_driverfunc_to_worker.py index a38d790..23482da 100644 --- a/Agent0/executor_train/verl/tests/single_controller/test_driverfunc_to_worker.py +++ b/Agent0/executor_train/verl/tests/single_controller/test_driverfunc_to_worker.py @@ -45,7 +45,8 @@ def get_aux_metrics(self, test_proto): decode_count.append(len(sequence_ids[i].tolist())) ret_proto = DataProto( batch=TensorDict( - {"sequence_ids": sequence_ids, "decode_count": torch.tensor(decode_count)}, batch_size=sequence_ids.size(0) + {"sequence_ids": sequence_ids, "decode_count": torch.tensor(decode_count)}, + batch_size=sequence_ids.size(0), ) ) return ret_proto @@ -79,6 +80,8 @@ def test(): hs = HackSelf() ret_proto2 = get_aux_metrics(hs, test_proto) - torch.testing.assert_close(ret_proto1.batch["decode_count"], ret_proto2.batch["decode_count"]) + torch.testing.assert_close( + ret_proto1.batch["decode_count"], ret_proto2.batch["decode_count"] + ) ray.shutdown() diff --git a/Agent0/executor_train/verl/tests/single_controller/test_fused_workers_on_cpu.py b/Agent0/executor_train/verl/tests/single_controller/test_fused_workers_on_cpu.py index 527ddc1..35f2e89 100644 --- a/Agent0/executor_train/verl/tests/single_controller/test_fused_workers_on_cpu.py +++ b/Agent0/executor_train/verl/tests/single_controller/test_fused_workers_on_cpu.py @@ -71,7 +71,9 @@ def test_fused_workers(): hybrid_cls_with_init = 
RayClassWithInitArgs(cls=HybridWorker) hybrid_cls_with_init.fused_worker_used = True - fused_wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=hybrid_cls_with_init) + fused_wg = RayWorkerGroup( + resource_pool=resource_pool, ray_cls_with_init=hybrid_cls_with_init + ) fused_wg.fuse(cls_dict.keys()) x = fused_wg.actor.add(0.1) diff --git a/Agent0/executor_train/verl/tests/single_controller/test_high_level_scheduling_api.py b/Agent0/executor_train/verl/tests/single_controller/test_high_level_scheduling_api.py index 52cc7c7..c326b6d 100644 --- a/Agent0/executor_train/verl/tests/single_controller/test_high_level_scheduling_api.py +++ b/Agent0/executor_train/verl/tests/single_controller/test_high_level_scheduling_api.py @@ -17,7 +17,12 @@ import ray from verl.single_controller.base.worker import Worker -from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup, merge_resource_pool +from verl.single_controller.ray.base import ( + RayClassWithInitArgs, + RayResourcePool, + RayWorkerGroup, + merge_resource_pool, +) @ray.remote @@ -40,18 +45,34 @@ def test(): class_with_args = RayClassWithInitArgs(cls=TestActor) print("create actor worker group") - actor_wg = RayWorkerGroup(resource_pool, class_with_args, name_prefix="high_level_api_actor") + actor_wg = RayWorkerGroup( + resource_pool, class_with_args, name_prefix="high_level_api_actor" + ) print("create critic worker group") - critic_wg = RayWorkerGroup(resource_pool, class_with_args, name_prefix="hight_level_api_critic") + critic_wg = RayWorkerGroup( + resource_pool, class_with_args, name_prefix="hight_level_api_critic" + ) print("create rm worker group") - rm_wg = RayWorkerGroup(resource_pool, class_with_args, name_prefix="high_level_api_rm") + rm_wg = RayWorkerGroup( + resource_pool, class_with_args, name_prefix="high_level_api_rm" + ) print("create ref worker group") - ref_wg = RayWorkerGroup(resource_pool, class_with_args, name_prefix="high_level_api_ref") - - 
assert actor_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(8)] - assert critic_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(8)] - assert rm_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(8)] - assert ref_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(8)] + ref_wg = RayWorkerGroup( + resource_pool, class_with_args, name_prefix="high_level_api_ref" + ) + + assert actor_wg.execute_all_sync("get_cuda_visible_devices") == [ + str(i) for i in range(8) + ] + assert critic_wg.execute_all_sync("get_cuda_visible_devices") == [ + str(i) for i in range(8) + ] + assert rm_wg.execute_all_sync("get_cuda_visible_devices") == [ + str(i) for i in range(8) + ] + assert ref_wg.execute_all_sync("get_cuda_visible_devices") == [ + str(i) for i in range(8) + ] del actor_wg del critic_wg @@ -72,14 +93,30 @@ def test(): assert ref_resource_pool.world_size == 4 assert total_resource_pool.world_size == 8 - actor_wg = RayWorkerGroup(total_resource_pool, class_with_args, name_prefix="high_level_api_actor") - critic_wg = RayWorkerGroup(total_resource_pool, class_with_args, name_prefix="high_level_api_critic") - rm_wg = RayWorkerGroup(rm_resource_pool, class_with_args, name_prefix="high_level_api_rm") - ref_wg = RayWorkerGroup(ref_resource_pool, class_with_args, name_prefix="high_level_api_ref") - - assert actor_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(8)] - assert critic_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(8)] - assert rm_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(4)] - assert ref_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(4, 8)] + actor_wg = RayWorkerGroup( + total_resource_pool, class_with_args, name_prefix="high_level_api_actor" + ) + critic_wg = RayWorkerGroup( + total_resource_pool, class_with_args, name_prefix="high_level_api_critic" + ) + rm_wg = 
RayWorkerGroup( + rm_resource_pool, class_with_args, name_prefix="high_level_api_rm" + ) + ref_wg = RayWorkerGroup( + ref_resource_pool, class_with_args, name_prefix="high_level_api_ref" + ) + + assert actor_wg.execute_all_sync("get_cuda_visible_devices") == [ + str(i) for i in range(8) + ] + assert critic_wg.execute_all_sync("get_cuda_visible_devices") == [ + str(i) for i in range(8) + ] + assert rm_wg.execute_all_sync("get_cuda_visible_devices") == [ + str(i) for i in range(4) + ] + assert ref_wg.execute_all_sync("get_cuda_visible_devices") == [ + str(i) for i in range(4, 8) + ] ray.shutdown() diff --git a/Agent0/executor_train/verl/tests/single_controller/test_ray_collectives.py b/Agent0/executor_train/verl/tests/single_controller/test_ray_collectives.py index 3722a8f..a300e2d 100644 --- a/Agent0/executor_train/verl/tests/single_controller/test_ray_collectives.py +++ b/Agent0/executor_train/verl/tests/single_controller/test_ray_collectives.py @@ -26,7 +26,11 @@ from verl.single_controller.base import Worker from verl.single_controller.base.decorator import Dispatch, register -from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup +from verl.single_controller.ray import ( + RayClassWithInitArgs, + RayResourcePool, + RayWorkerGroup, +) @ray.remote @@ -35,7 +39,9 @@ class Actor(Worker): def init(self): remote_rank = self.rank // 2 self.group_name = f"A{self.rank}_R{remote_rank}" - collective.init_collective_group(world_size=2, rank=0, backend="nccl", group_name=self.group_name) + collective.init_collective_group( + world_size=2, rank=0, backend="nccl", group_name=self.group_name + ) @register(Dispatch.ONE_TO_ALL, blocking=False) def send_tensors(self): @@ -52,8 +58,12 @@ def init(self): self.first_group_name = f"A{self.remote_first_rank}_R{self.rank}" self.second_group_name = f"A{self.remote_second_rank}_R{self.rank}" - collective.init_collective_group(world_size=2, rank=1, backend="nccl", group_name=self.first_group_name) - 
collective.init_collective_group(world_size=2, rank=1, backend="nccl", group_name=self.second_group_name) + collective.init_collective_group( + world_size=2, rank=1, backend="nccl", group_name=self.first_group_name + ) + collective.init_collective_group( + world_size=2, rank=1, backend="nccl", group_name=self.second_group_name + ) @register(Dispatch.ONE_TO_ALL, blocking=False) def receive_tensors(self): @@ -65,7 +75,10 @@ def receive_tensors(self): @register(Dispatch.ONE_TO_ALL) def get_tensors(self): - return {f"src_{self.remote_first_rank}": self.tensor1, f"src_{self.remote_second_rank}": self.tensor2} + return { + f"src_{self.remote_first_rank}": self.tensor1, + f"src_{self.remote_second_rank}": self.tensor2, + } def test_ray_collective_group(): @@ -78,10 +91,14 @@ def test_ray_collective_group(): rollout_cls = RayClassWithInitArgs(cls=Rollout) actor_wg = RayWorkerGroup( - resource_pool=actor_resource_pool, ray_cls_with_init=actor_cls, name_prefix="collective_group_actor" + resource_pool=actor_resource_pool, + ray_cls_with_init=actor_cls, + name_prefix="collective_group_actor", ) rollout_wg = RayWorkerGroup( - resource_pool=rollout_resource_pool, ray_cls_with_init=rollout_cls, name_prefix="collective_group_rollout" + resource_pool=rollout_resource_pool, + ray_cls_with_init=rollout_cls, + name_prefix="collective_group_rollout", ) actor_wg.init() diff --git a/Agent0/executor_train/verl/tests/single_controller/test_ray_local_envs_on_cpu.py b/Agent0/executor_train/verl/tests/single_controller/test_ray_local_envs_on_cpu.py index ee6c0cb..945df86 100644 --- a/Agent0/executor_train/verl/tests/single_controller/test_ray_local_envs_on_cpu.py +++ b/Agent0/executor_train/verl/tests/single_controller/test_ray_local_envs_on_cpu.py @@ -20,7 +20,11 @@ import ray from verl.single_controller.base.worker import Worker -from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup +from verl.single_controller.ray.base import ( + 
RayClassWithInitArgs, + RayResourcePool, + RayWorkerGroup, +) @ray.remote @@ -41,7 +45,9 @@ def test_basics(): class_with_args = RayClassWithInitArgs(cls=TestActor) worker_group = RayWorkerGroup( - resource_pool=resource_pool, ray_cls_with_init=class_with_args, name_prefix="worker_group_basic" + resource_pool=resource_pool, + ray_cls_with_init=class_with_args, + name_prefix="worker_group_basic", ) output = worker_group.execute_all_sync("getenv", key="RAY_LOCAL_WORLD_SIZE") diff --git a/Agent0/executor_train/verl/tests/single_controller/test_rvdz.py b/Agent0/executor_train/verl/tests/single_controller/test_rvdz.py index 7dea12f..5736a89 100644 --- a/Agent0/executor_train/verl/tests/single_controller/test_rvdz.py +++ b/Agent0/executor_train/verl/tests/single_controller/test_rvdz.py @@ -26,7 +26,9 @@ def __init__(self, rank, world_size, group_name): def init(self): from verl.utils.rendezvous.ray_backend import create_nccl_communicator_in_ray - self.communicator = create_nccl_communicator_in_ray(self.rank, self.world_size, self.group_name) + self.communicator = create_nccl_communicator_in_ray( + self.rank, self.world_size, self.group_name + ) def test(self): if self.communicator is None: @@ -40,7 +42,10 @@ def test_rvdz(): group_name = "test_group" world_size = 2 - workers = [TestWorker.options(num_gpus=1).remote(rank, world_size, group_name) for rank in range(world_size)] + workers = [ + TestWorker.options(num_gpus=1).remote(rank, world_size, group_name) + for rank in range(world_size) + ] ray.get([worker.init.remote() for worker in workers]) diff --git a/Agent0/executor_train/verl/tests/single_controller/test_worker_group_basics.py b/Agent0/executor_train/verl/tests/single_controller/test_worker_group_basics.py index 5c4823d..854d164 100644 --- a/Agent0/executor_train/verl/tests/single_controller/test_worker_group_basics.py +++ b/Agent0/executor_train/verl/tests/single_controller/test_worker_group_basics.py @@ -18,9 +18,18 @@ import ray import torch -from 
verl.single_controller.base.decorator import Dispatch, Execute, collect_all_to_all, register +from verl.single_controller.base.decorator import ( + Dispatch, + Execute, + collect_all_to_all, + register, +) from verl.single_controller.base.worker import Worker -from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup +from verl.single_controller.ray.base import ( + RayClassWithInitArgs, + RayResourcePool, + RayWorkerGroup, +) def two_to_all_dispatch_fn(worker_group, *args, **kwargs): @@ -60,7 +69,12 @@ def foo_one_to_all(self, x, y): def foo_all_to_all(self, x, y): return self._x + y + x - @register(dispatch_mode={"dispatch_fn": two_to_all_dispatch_fn, "collect_fn": collect_all_to_all}) + @register( + dispatch_mode={ + "dispatch_fn": two_to_all_dispatch_fn, + "collect_fn": collect_all_to_all, + } + ) def foo_custom(self, x, y): return self._x + y + x @@ -97,7 +111,9 @@ def test_basics(): class_with_args = RayClassWithInitArgs(cls=TestActor, x=2) worker_group = RayWorkerGroup( - resource_pool=resource_pool, ray_cls_with_init=class_with_args, name_prefix="worker_group_basic" + resource_pool=resource_pool, + ray_cls_with_init=class_with_args, + name_prefix="worker_group_basic", ) print(worker_group.worker_names) diff --git a/Agent0/executor_train/verl/tests/single_controller/test_worker_group_torch.py b/Agent0/executor_train/verl/tests/single_controller/test_worker_group_torch.py index a601c43..7db37ff 100644 --- a/Agent0/executor_train/verl/tests/single_controller/test_worker_group_torch.py +++ b/Agent0/executor_train/verl/tests/single_controller/test_worker_group_torch.py @@ -22,7 +22,11 @@ import torch.distributed from verl.single_controller.base.worker import Worker -from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup +from verl.single_controller.ray.base import ( + RayClassWithInitArgs, + RayResourcePool, + RayWorkerGroup, +) @ray.remote @@ -39,7 +43,9 @@ def init(self): def 
all_gather(self): world_size = self._world_size output = torch.zeros( - size=(self.tensor.shape[0] * world_size,), dtype=self.tensor.dtype, device=self.tensor.device + size=(self.tensor.shape[0] * world_size,), + dtype=self.tensor.dtype, + device=self.tensor.device, ) torch.distributed.all_gather_into_tensor(output, self.tensor, async_op=False) return output @@ -58,7 +64,9 @@ def __init__(self, size) -> None: def all_gather(self): world_size = self._world_size output = torch.zeros( - size=(self.tensor.shape[0] * world_size,), dtype=self.tensor.dtype, device=self.tensor.device + size=(self.tensor.shape[0] * world_size,), + dtype=self.tensor.dtype, + device=self.tensor.device, ) torch.distributed.all_gather_into_tensor(output, self.tensor, async_op=False) return output @@ -74,7 +82,9 @@ def test_all_gather_torch(): resource_pool = RayResourcePool([4], use_gpu=True) class_with_args = RayClassWithInitArgs(cls=TestAllGatherActor, size=2) - worker_group = RayWorkerGroup(resource_pool, class_with_args, name_prefix="worker_group_torch") + worker_group = RayWorkerGroup( + resource_pool, class_with_args, name_prefix="worker_group_torch" + ) worker_group.execute_all_sync("init") output = worker_group.execute_all_sync("all_gather") @@ -83,7 +93,9 @@ def test_all_gather_torch(): output = output[0].cpu() print(output) - assert torch.all(output == torch.tensor([0, 0, 1, 1, 2, 2, 3, 3], dtype=torch.int64)) + assert torch.all( + output == torch.tensor([0, 0, 1, 1, 2, 2, 3, 3], dtype=torch.int64) + ) ray.shutdown() @@ -98,7 +110,9 @@ def test_all_gather_torch_v2(): resource_pool = RayResourcePool([4], use_gpu=True) class_with_args = RayClassWithInitArgs(cls=TestAllGatherActorV2, size=2) - worker_group = RayWorkerGroup(resource_pool, class_with_args, name_prefix="worker_group_torch") + worker_group = RayWorkerGroup( + resource_pool, class_with_args, name_prefix="worker_group_torch" + ) output = worker_group.execute_all_sync("all_gather") for i in range(1, len(output)): @@ -106,6 
+120,8 @@ def test_all_gather_torch_v2(): output = output[0].cpu() print(output) - assert torch.all(output == torch.tensor([0, 0, 1, 1, 2, 2, 3, 3], dtype=torch.int64)) + assert torch.all( + output == torch.tensor([0, 0, 1, 1, 2, 2, 3, 3], dtype=torch.int64) + ) ray.shutdown() diff --git a/Agent0/executor_train/verl/tests/special_distributed/test_fsdp_ckpt.py b/Agent0/executor_train/verl/tests/special_distributed/test_fsdp_ckpt.py index 49dceb7..e6431dd 100644 --- a/Agent0/executor_train/verl/tests/special_distributed/test_fsdp_ckpt.py +++ b/Agent0/executor_train/verl/tests/special_distributed/test_fsdp_ckpt.py @@ -30,21 +30,27 @@ def test_fsdp_ckpt(strategy="fsdp"): assert torch.cuda.device_count() >= 2, "need at least 2 gpus for test" local_rank, rank, world_size = initialize_global_process_group() - device_mesh = init_device_mesh("cuda", mesh_shape=(world_size,), mesh_dim_names=("dp",)) + device_mesh = init_device_mesh( + "cuda", mesh_shape=(world_size,), mesh_dim_names=("dp",) + ) model_name = "Qwen/Qwen2.5-0.5B-Instruct" config = Qwen2Config(num_hidden_layers=1) with torch.device("cuda"): model = AutoModelForCausalLM.from_config( - config=config, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" + config=config, + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2", ) model = model.to(device="cuda") # Wrap model with FSDP if strategy == "fsdp": mixed_precision = MixedPrecision( - param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.float32 + param_dtype=torch.bfloat16, + reduce_dtype=torch.float32, + buffer_dtype=torch.float32, ) model = FSDP( @@ -57,7 +63,9 @@ def test_fsdp_ckpt(strategy="fsdp"): ) else: mp_policy = MixedPrecisionPolicy( - param_dtype=torch.bfloat16, reduce_dtype=torch.float32, cast_forward_inputs=True + param_dtype=torch.bfloat16, + reduce_dtype=torch.float32, + cast_forward_inputs=True, ) fsdp_kwargs = { "mesh": device_mesh, @@ -97,7 +105,9 @@ def test_fsdp_ckpt(strategy="fsdp"): # 
Save checkpoint after first update temp_dir = tempfile.mkdtemp() checkpoint_path = os.path.join(temp_dir, "checkpoint") - checkpoint_manager.save_checkpoint(local_path=checkpoint_path, hdfs_path=None, global_step=0) + checkpoint_manager.save_checkpoint( + local_path=checkpoint_path, hdfs_path=None, global_step=0 + ) # Step 2: Second update and forward pass outputs2 = model(input_ids=input_ids2, attention_mask=attention_mask2) @@ -109,7 +119,9 @@ def test_fsdp_ckpt(strategy="fsdp"): # Record logits after second update with torch.no_grad(): - logits_before_load = model(input_ids=input_ids2, attention_mask=attention_mask2).logits + logits_before_load = model( + input_ids=input_ids2, attention_mask=attention_mask2 + ).logits # Step 3: Load checkpoint and repeat second update checkpoint_manager.load_checkpoint(checkpoint_path) @@ -124,10 +136,14 @@ def test_fsdp_ckpt(strategy="fsdp"): # Record logits after loaded checkpoint and update with torch.no_grad(): - logits_after_load = model(input_ids=input_ids2, attention_mask=attention_mask2).logits + logits_after_load = model( + input_ids=input_ids2, attention_mask=attention_mask2 + ).logits # Step 4: Verify outputs match - torch.testing.assert_close(logits_before_load, logits_after_load, atol=0.0, rtol=0.0) + torch.testing.assert_close( + logits_before_load, logits_after_load, atol=0.0, rtol=0.0 + ) print("Checkpoint save/load test passed!") # Cleanup diff --git a/Agent0/executor_train/verl/tests/special_distributed/test_tensor_dict.py b/Agent0/executor_train/verl/tests/special_distributed/test_tensor_dict.py index 0a7f803..b260b89 100644 --- a/Agent0/executor_train/verl/tests/special_distributed/test_tensor_dict.py +++ b/Agent0/executor_train/verl/tests/special_distributed/test_tensor_dict.py @@ -25,15 +25,23 @@ def test_all_gather_data_proto(): - device_mesh = torch.distributed.device_mesh.init_device_mesh("cuda", mesh_shape=[2, 2], mesh_dim_names=["dp", "tp"]) + device_mesh = 
torch.distributed.device_mesh.init_device_mesh( + "cuda", mesh_shape=[2, 2], mesh_dim_names=["dp", "tp"] + ) global_rank = torch.distributed.get_rank() - obs = torch.tensor([[1 * global_rank, 2 * global_rank + 1], [3 * global_rank, 4 * global_rank + 1]]) + obs = torch.tensor( + [[1 * global_rank, 2 * global_rank + 1], [3 * global_rank, 4 * global_rank + 1]] + ) labels = ["a", "b"] if global_rank % 2 == 0 else ["b", "a"] labels = np.array(labels, dtype=object) - data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"labels": labels}, meta_info={"info": "test_info"}) + data = DataProto.from_dict( + tensors={"obs": obs}, + non_tensors={"labels": labels}, + meta_info={"info": "test_info"}, + ) all_gather_data_proto(data=data, process_group=device_mesh.get_group("dp")) @@ -63,22 +71,36 @@ def test_vocab_parallel_entropy(): from verl.utils.torch_functional import entropy_from_logits mpu.initialize_model_parallel( - tensor_model_parallel_size=2, pipeline_model_parallel_size=1, virtual_pipeline_model_parallel_size=None + tensor_model_parallel_size=2, + pipeline_model_parallel_size=1, + virtual_pipeline_model_parallel_size=None, ) batch_size = 2 seqlen = 128 vocab_size = 155136 - logits = torch.randn(batch_size * seqlen, vocab_size, device="cuda", requires_grad=True) - target = torch.randint(low=0, high=vocab_size, size=(batch_size * seqlen,), device="cuda", dtype=torch.int64) + logits = torch.randn( + batch_size * seqlen, vocab_size, device="cuda", requires_grad=True + ) + target = torch.randint( + low=0, + high=vocab_size, + size=(batch_size * seqlen,), + device="cuda", + dtype=torch.int64, + ) # broadcast across tp torch.distributed.broadcast( - logits, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group() + logits, + mpu.get_tensor_model_parallel_src_rank(), + group=mpu.get_tensor_model_parallel_group(), ) torch.distributed.broadcast( - target, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group() + 
target, + mpu.get_tensor_model_parallel_src_rank(), + group=mpu.get_tensor_model_parallel_group(), ) tp_rank = mpu.get_tensor_model_parallel_rank() @@ -86,7 +108,9 @@ def test_vocab_parallel_entropy(): # get the local logits of each tp vocab_parallel_logits = ( - logits.clone().detach()[:, tp_rank * vocab_size_per_tp : (tp_rank + 1) * vocab_size_per_tp].requires_grad_() + logits.clone() + .detach()[:, tp_rank * vocab_size_per_tp : (tp_rank + 1) * vocab_size_per_tp] + .requires_grad_() ) logits.grad = None vocab_parallel_logits.grad = None @@ -102,11 +126,13 @@ def test_vocab_parallel_entropy(): torch.testing.assert_close(output_entropy, target_entropy) target_entropy.backward(grad_output) torch.testing.assert_close( - logits.grad[:, tp_rank * vocab_size_per_tp : (tp_rank + 1) * vocab_size_per_tp], vocab_parallel_logits.grad + logits.grad[:, tp_rank * vocab_size_per_tp : (tp_rank + 1) * vocab_size_per_tp], + vocab_parallel_logits.grad, ) # make sure logits is not altered torch.testing.assert_close( - logits[:, tp_rank * vocab_size_per_tp : (tp_rank + 1) * vocab_size_per_tp], vocab_parallel_logits + logits[:, tp_rank * vocab_size_per_tp : (tp_rank + 1) * vocab_size_per_tp], + vocab_parallel_logits, ) if mpu.get_tensor_model_parallel_rank() == 0: diff --git a/Agent0/executor_train/verl/tests/special_e2e/check_custom_rwd_fn.py b/Agent0/executor_train/verl/tests/special_e2e/check_custom_rwd_fn.py index 8d77a53..c1cc631 100644 --- a/Agent0/executor_train/verl/tests/special_e2e/check_custom_rwd_fn.py +++ b/Agent0/executor_train/verl/tests/special_e2e/check_custom_rwd_fn.py @@ -19,8 +19,12 @@ def check_congratulations_in_file(output_file): with open(output_file) as f: output = f.read() - success_message = "Congratulations!!! You have called my_reward_function successfully!!!" - assert success_message in output, f"Success message of my_reward_function not found in {output_file}" + success_message = ( + "Congratulations!!! You have called my_reward_function successfully!!!" 
+ ) + assert ( + success_message in output + ), f"Success message of my_reward_function not found in {output_file}" print("Check passes") diff --git a/Agent0/executor_train/verl/tests/special_e2e/check_results.py b/Agent0/executor_train/verl/tests/special_e2e/check_results.py index 9453282..f189d36 100644 --- a/Agent0/executor_train/verl/tests/special_e2e/check_results.py +++ b/Agent0/executor_train/verl/tests/special_e2e/check_results.py @@ -49,5 +49,7 @@ def extract_reward_from_line(line): best_reward = reward print(f"Best reward is {best_reward}") - assert best_reward > args.target, f"Best reward must be greater than {args.target}. best_reward: {best_reward}" + assert ( + best_reward > args.target + ), f"Best reward must be greater than {args.target}. best_reward: {best_reward}" print("Check passes") diff --git a/Agent0/executor_train/verl/tests/special_e2e/envs/digit_completion/task.py b/Agent0/executor_train/verl/tests/special_e2e/envs/digit_completion/task.py index c3643a8..54c1658 100644 --- a/Agent0/executor_train/verl/tests/special_e2e/envs/digit_completion/task.py +++ b/Agent0/executor_train/verl/tests/special_e2e/envs/digit_completion/task.py @@ -32,7 +32,9 @@ class DigitCompletion: Note that the tokenizer is char-level to increase the difficulty. 
""" - def __init__(self, max_number: int, max_diff: int, max_num_in_response: int, seed=0): + def __init__( + self, max_number: int, max_diff: int, max_num_in_response: int, seed=0 + ): """ Args: @@ -49,7 +51,9 @@ def __init__(self, max_number: int, max_diff: int, max_num_in_response: int, see assert self.max_diff > 0 self.max_number_length = len(str(max_number)) # {num1},{num2}:{max_num_in_response},{max_number} - self._prompt_length = self.max_number_length * 2 + 4 + self.max_number_length # no negative is allowed + self._prompt_length = ( + self.max_number_length * 2 + 4 + self.max_number_length + ) # no negative is allowed self.np_rng = np.random.default_rng(seed=seed) @@ -75,7 +79,11 @@ def prompt_length(self): def response_length(self): # number length + comma length + [EOS] # The actual number times 1.5 to allow 'U' - return (self.max_num_in_response * self.max_number_length + (self.max_num_in_response - 1) + 1) * 2 + return ( + self.max_num_in_response * self.max_number_length + + (self.max_num_in_response - 1) + + 1 + ) * 2 def add(self, a, b): return (a + b) % self.max_number @@ -86,7 +94,12 @@ def get_all_prompts(self): for diff in range(0, self.max_diff + 1): second_num = self.add(first_num, diff) for num_to_complete in range(self.max_num_in_response + 1): - prompt = str(first_num) + "," + str(second_num) + f":{self.max_number},{num_to_complete}" + prompt = ( + str(first_num) + + "," + + str(second_num) + + f":{self.max_number},{num_to_complete}" + ) all_prompts.append(prompt) return all_prompts @@ -96,7 +109,12 @@ def sample_str_prompts(self): diff = self.np_rng.integers(self.max_diff + 1) second_num = self.add(first_num, diff) num_to_complete = self.np_rng.integers(self.max_num_in_response + 1) - prompt = str(first_num) + "," + str(second_num) + f":{self.max_number},{num_to_complete}" + prompt = ( + str(first_num) + + "," + + str(second_num) + + f":{self.max_number},{num_to_complete}" + ) return prompt def sample_batch_str_prompts(self, batch_size): 
@@ -140,10 +158,14 @@ def compute_reward(prompt: str, response: str, sequence_reward=1.0): """We compute dense reward here so that we can directly train RL without SFT""" response_length = len(response) ground_truth_response = generate_ground_truth_response(prompt) - per_token_reward = sequence_reward / (len(ground_truth_response) + 1) # including [EOS] + per_token_reward = sequence_reward / ( + len(ground_truth_response) + 1 + ) # including [EOS] # pad - reward = np.zeros(response_length, dtype=np.float32) # this assumes that each char is a token + reward = np.zeros( + response_length, dtype=np.float32 + ) # this assumes that each char is a token # assign reward until mismatches ground_truth_idx = 0 for i in range(response_length): diff --git a/Agent0/executor_train/verl/tests/special_e2e/envs/digit_completion/tokenizer.py b/Agent0/executor_train/verl/tests/special_e2e/envs/digit_completion/tokenizer.py index 6ff4719..1242f31 100644 --- a/Agent0/executor_train/verl/tests/special_e2e/envs/digit_completion/tokenizer.py +++ b/Agent0/executor_train/verl/tests/special_e2e/envs/digit_completion/tokenizer.py @@ -27,7 +27,9 @@ class CharTokenizer(PreTrainedTokenizer): - def __init__(self, characters: Sequence[str], model_max_length: int, chat_template, **kwargs): + def __init__( + self, characters: Sequence[str], model_max_length: int, chat_template, **kwargs + ): """Character tokenizer for Hugging Face transformers. 
Args: diff --git a/Agent0/executor_train/verl/tests/special_e2e/sft/test_sp_loss_match.py b/Agent0/executor_train/verl/tests/special_e2e/sft/test_sp_loss_match.py index 4dc0cbd..e11a862 100644 --- a/Agent0/executor_train/verl/tests/special_e2e/sft/test_sp_loss_match.py +++ b/Agent0/executor_train/verl/tests/special_e2e/sft/test_sp_loss_match.py @@ -29,8 +29,12 @@ def test_trainer_forward_consistency(trainer: FSDPSFTTrainer, total_steps: int = total_steps: Number of steps to test (default: 4) """ if trainer.device_mesh.get_rank() == 0: - print("\nStarting debug comparison between original and SP+rmpad forward passes...") - print(f"Sequence parallel size: {trainer.config.ulysses_sequence_parallel_size}") + print( + "\nStarting debug comparison between original and SP+rmpad forward passes..." + ) + print( + f"Sequence parallel size: {trainer.config.ulysses_sequence_parallel_size}" + ) print(f"Remove padding: {trainer.use_remove_padding}\n") steps_remaining = total_steps @@ -38,7 +42,9 @@ def test_trainer_forward_consistency(trainer: FSDPSFTTrainer, total_steps: int = for epoch in range(1): # Just one epoch for testing trainer.train_sampler.set_epoch(epoch=epoch) for data in trainer.train_dataloader: - data = TensorDict(data, batch_size=trainer.config.data.train_batch_size).cuda() + data = TensorDict( + data, batch_size=trainer.config.data.train_batch_size + ).cuda() trainer.fsdp_model.train() micro_batches = data.split(trainer.config.data.micro_batch_size_per_gpu) @@ -51,21 +57,31 @@ def test_trainer_forward_consistency(trainer: FSDPSFTTrainer, total_steps: int = trainer.use_remove_padding = False old_sp = trainer.config.ulysses_sequence_parallel_size trainer.config.ulysses_sequence_parallel_size = 1 - loss_ref = trainer._compute_loss_and_backward(micro_batch.copy(), do_backward=False) + loss_ref = trainer._compute_loss_and_backward( + micro_batch.copy(), do_backward=False + ) # Do SP and rmpad trainer.config.ulysses_sequence_parallel_size = old_sp 
trainer.use_remove_padding = True - loss_sp = trainer._compute_loss_and_backward(micro_batch.copy(), do_backward=False) + loss_sp = trainer._compute_loss_and_backward( + micro_batch.copy(), do_backward=False + ) # Collect losses across all ranks loss_ref_all = loss_ref.clone() loss_sp_all = loss_sp.clone() - torch.distributed.all_reduce(loss_ref_all, op=torch.distributed.ReduceOp.AVG) - torch.distributed.all_reduce(loss_sp_all, op=torch.distributed.ReduceOp.AVG) + torch.distributed.all_reduce( + loss_ref_all, op=torch.distributed.ReduceOp.AVG + ) + torch.distributed.all_reduce( + loss_sp_all, op=torch.distributed.ReduceOp.AVG + ) # Calculate relative difference of averaged losses - rel_diff = torch.abs(loss_ref_all - loss_sp_all) / (torch.abs(loss_ref_all) + 1e-8) + rel_diff = torch.abs(loss_ref_all - loss_sp_all) / ( + torch.abs(loss_ref_all) + 1e-8 + ) if trainer.device_mesh.get_rank() == 0: print("\nComparison Results (Averaged across ranks):") @@ -73,7 +89,9 @@ def test_trainer_forward_consistency(trainer: FSDPSFTTrainer, total_steps: int = print(f"SP+rmpad Loss: {loss_sp_all.item():.6f}") print(f"Relative Difference: {rel_diff.item():.6f}") - assert rel_diff.item() < 1e-2, "Significant difference detected between averaged losses!" + assert ( + rel_diff.item() < 1e-2 + ), "Significant difference detected between averaged losses!" 
print("Loss difference is within the acceptable range.") steps_remaining -= 1 @@ -98,11 +116,15 @@ def create_trainer(config): """ local_rank, rank, world_size = initialize_global_process_group() - device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,), mesh_dim_names=("fsdp",)) + device_mesh = init_device_mesh( + device_type="cuda", mesh_shape=(world_size,), mesh_dim_names=("fsdp",) + ) dp_size = world_size // config.ulysses_sequence_parallel_size ulysses_device_mesh = init_device_mesh( - device_type="cuda", mesh_shape=(dp_size, config.ulysses_sequence_parallel_size), mesh_dim_names=("dp", "sp") + device_type="cuda", + mesh_shape=(dp_size, config.ulysses_sequence_parallel_size), + mesh_dim_names=("dp", "sp"), ) # build tokenizer and datasets first @@ -111,7 +133,9 @@ def create_trainer(config): from verl.utils.fs import copy_to_local local_model_path = copy_to_local(src=config.model.partial_pretrain, verbose=True) - tokenizer = hf_tokenizer(local_model_path, trust_remote_code=config.model.trust_remote_code) + tokenizer = hf_tokenizer( + local_model_path, trust_remote_code=config.model.trust_remote_code + ) train_dataset = create_sft_dataset(config.data.train_files, config.data, tokenizer) val_dataset = create_sft_dataset(config.data.val_files, config.data, tokenizer) diff --git a/Agent0/executor_train/verl/tests/special_sanity/check_api_docs.py b/Agent0/executor_train/verl/tests/special_sanity/check_api_docs.py index fa31ec8..aa7a4af 100644 --- a/Agent0/executor_train/verl/tests/special_sanity/check_api_docs.py +++ b/Agent0/executor_train/verl/tests/special_sanity/check_api_docs.py @@ -55,7 +55,9 @@ def iter_submodules(root: ModuleType) -> Iterable[ModuleType]: """Yield *root* and every sub-module inside it.""" yield root if getattr(root, "__path__", None): # only packages have __path__ - for mod_info in pkgutil.walk_packages(root.__path__, prefix=f"{root.__name__}."): + for mod_info in pkgutil.walk_packages( + root.__path__, 
prefix=f"{root.__name__}." + ): try: yield importlib.import_module(mod_info.name) except Exception as exc: # noqa: BLE001 @@ -116,7 +118,9 @@ def main() -> None: targets = args.modules or autodiscover_packages() if not targets: - raise ValueError("[error] No modules specified and none detected automatically.") + raise ValueError( + "[error] No modules specified and none detected automatically." + ) all_missing: list[str] = [] for modname in targets: @@ -126,7 +130,9 @@ def main() -> None: print("\nMissing docstrings:") for name in sorted(all_missing): print(f" - {name}") - raise ValueError("Missing docstrings detected. Please enhance them with docs accordingly.") + raise ValueError( + "Missing docstrings detected. Please enhance them with docs accordingly." + ) print("โœ… All exported functions/classes have docstrings.") diff --git a/Agent0/executor_train/verl/tests/special_sanity/check_device_api_usage.py b/Agent0/executor_train/verl/tests/special_sanity/check_device_api_usage.py index c8988db..bdc8ee2 100644 --- a/Agent0/executor_train/verl/tests/special_sanity/check_device_api_usage.py +++ b/Agent0/executor_train/verl/tests/special_sanity/check_device_api_usage.py @@ -65,7 +65,9 @@ # for easy debugging in non-linux system sw = sw.replace("/", os.sep) if sw in path_in_str: - print(f"[SKIP] File {path_in_str} is in device api usage check whitelist, checking is skipped.") + print( + f"[SKIP] File {path_in_str} is in device api usage check whitelist, checking is skipped." 
+ ) path_in_whitelist = True break diff --git a/Agent0/executor_train/verl/tests/special_sanity/check_docs_time_info.py b/Agent0/executor_train/verl/tests/special_sanity/check_docs_time_info.py index a54d1d5..ebaa8be 100644 --- a/Agent0/executor_train/verl/tests/special_sanity/check_docs_time_info.py +++ b/Agent0/executor_train/verl/tests/special_sanity/check_docs_time_info.py @@ -51,7 +51,10 @@ def is_allowed(path: Path) -> bool: def main(): if not DOCS_DIR.exists(): - print(f"Error: Documentation directory '{DOCS_DIR}' does not exist.", file=sys.stderr) + print( + f"Error: Documentation directory '{DOCS_DIR}' does not exist.", + file=sys.stderr, + ) sys.exit(1) missing = [] diff --git a/Agent0/executor_train/verl/tests/special_sanity/check_docstrings.py b/Agent0/executor_train/verl/tests/special_sanity/check_docstrings.py index 7c5d8ed..26060fe 100644 --- a/Agent0/executor_train/verl/tests/special_sanity/check_docstrings.py +++ b/Agent0/executor_train/verl/tests/special_sanity/check_docstrings.py @@ -35,7 +35,11 @@ def visit_FunctionDef(self, node: ast.FunctionDef): """Visit function definitions and check for docstrings.""" if not node.name.startswith("_") and self.function_nesting_level == 0: if not self._has_docstring(node): - func_name = f"{self.current_class}.{node.name}" if self.current_class else node.name + func_name = ( + f"{self.current_class}.{node.name}" + if self.current_class + else node.name + ) self.missing_docstrings.append((func_name, self.filename, node.lineno)) self.function_nesting_level += 1 @@ -46,7 +50,11 @@ def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef): """Visit async function definitions and check for docstrings.""" if not node.name.startswith("_") and self.function_nesting_level == 0: if not self._has_docstring(node): - func_name = f"{self.current_class}.{node.name}" if self.current_class else node.name + func_name = ( + f"{self.current_class}.{node.name}" + if self.current_class + else node.name + ) 
self.missing_docstrings.append((func_name, self.filename, node.lineno)) self.function_nesting_level += 1 @@ -130,7 +138,9 @@ def main(): print("=" * 60) if all_missing_docstrings: - print(f"\nSUMMARY: Found {len(all_missing_docstrings)} functions/classes missing docstrings:") + print( + f"\nSUMMARY: Found {len(all_missing_docstrings)} functions/classes missing docstrings:" + ) print("-" * 60) by_file = {} @@ -146,7 +156,9 @@ def main(): print(f"\nTotal missing docstrings: {len(all_missing_docstrings)}") - raise Exception(f"Found {len(all_missing_docstrings)} functions/classes without proper docstrings!") + raise Exception( + f"Found {len(all_missing_docstrings)} functions/classes without proper docstrings!" + ) else: print("\nโœ… All functions and classes have proper docstrings!") diff --git a/Agent0/executor_train/verl/tests/special_sanity/check_pr_description.py b/Agent0/executor_train/verl/tests/special_sanity/check_pr_description.py index 4ed4563..07587a4 100644 --- a/Agent0/executor_train/verl/tests/special_sanity/check_pr_description.py +++ b/Agent0/executor_train/verl/tests/special_sanity/check_pr_description.py @@ -34,7 +34,9 @@ class PRDescriptionError(Exception): # Path to the PR template file -template_file = os.path.join(os.getenv("GITHUB_WORKSPACE", "."), ".github", "PULL_REQUEST_TEMPLATE.md") +template_file = os.path.join( + os.getenv("GITHUB_WORKSPACE", "."), ".github", "PULL_REQUEST_TEMPLATE.md" +) def load_template(path): @@ -52,7 +54,9 @@ def load_template(path): lines.append(line.strip()) return lines except Exception as e: - raise TemplateFileError(f"Failed to read PR template (first {NUM_LINES} lines) at {path}: {e}") from e + raise TemplateFileError( + f"Failed to read PR template (first {NUM_LINES} lines) at {path}: {e}" + ) from e def load_pr_body(event_path): diff --git a/Agent0/executor_train/verl/tests/special_sanity/check_pr_title.py b/Agent0/executor_train/verl/tests/special_sanity/check_pr_title.py index f4cbd52..d4ed666 100644 --- 
a/Agent0/executor_train/verl/tests/special_sanity/check_pr_title.py +++ b/Agent0/executor_train/verl/tests/special_sanity/check_pr_title.py @@ -22,7 +22,17 @@ allowed_modules = ["fsdp", "megatron", "sglang", "vllm", "rollout", "trainer"] allowed_modules += ["tests", "training_utils", "recipe", "hardware", "deployment"] allowed_modules += ["ray", "worker", "single_controller", "misc", "docker", "ci"] -allowed_modules += ["perf", "model", "algo", "env", "tool", "ckpt", "doc", "data", "cfg"] +allowed_modules += [ + "perf", + "model", + "algo", + "env", + "tool", + "ckpt", + "doc", + "data", + "cfg", +] allowed_types = ["feat", "fix", "refactor", "chore", "test"] # Check for [BREAKING] prefix and extract the rest of the title @@ -45,13 +55,17 @@ else: modules = re.findall(r"[a-z_]+", re_modules.group(1).lower()) if not all(module in allowed_modules for module in modules): - invalid_modules = [module for module in modules if module not in allowed_modules] + invalid_modules = [ + module for module in modules if module not in allowed_modules + ] print(f"โŒ Invalid modules: {', '.join(invalid_modules)}") print(f"Allowed modules: {', '.join(allowed_modules)}") raise Exception("Invalid PR title") types_pattern = "|".join(re.escape(t) for t in allowed_types) -re_types_pattern = re.compile(rf"^\[[a-z_,\s]+\]\s+({types_pattern}):\s+.+$", re.IGNORECASE) +re_types_pattern = re.compile( + rf"^\[[a-z_,\s]+\]\s+({types_pattern}):\s+.+$", re.IGNORECASE +) match = re_types_pattern.match(core_pr_title) if not match: @@ -64,4 +78,6 @@ # Build the success message breaking_info = " (BREAKING CHANGE)" if is_breaking else "" -print(f"โœ… PR title is valid: {pr_title}, modules: {modules}, type: {change_type}{breaking_info}") +print( + f"โœ… PR title is valid: {pr_title}, modules: {modules}, type: {change_type}{breaking_info}" +) diff --git a/Agent0/executor_train/verl/tests/special_sanity/test_config_docs.py b/Agent0/executor_train/verl/tests/special_sanity/test_config_docs.py index 
2f260f1..cfd099f 100644 --- a/Agent0/executor_train/verl/tests/special_sanity/test_config_docs.py +++ b/Agent0/executor_train/verl/tests/special_sanity/test_config_docs.py @@ -41,7 +41,9 @@ def validate_yaml_format(yaml_lines): comment_index = line.index("#") colon_index = line.index(":") if comment_index > colon_index: - errors.append(f"Inline comment found on line {i + 1}: {line.strip()}") + errors.append( + f"Inline comment found on line {i + 1}: {line.strip()}" + ) # Check for blank line after this key line (unless next is a deeper indent) if i + 1 < len(yaml_lines): @@ -50,7 +52,9 @@ def validate_yaml_format(yaml_lines): # If next is not empty and not a deeper nested line, enforce blank line if next_stripped != "": - errors.append(f"Missing blank line after line {i + 1}: {line.strip()}") + errors.append( + f"Missing blank line after line {i + 1}: {line.strip()}" + ) i += 1 @@ -76,7 +80,9 @@ def test_trainer_config_doc(): if validation_errors: success = False print("YAML documentation format check failed:") - print(f"Please read the top block of {yaml_to_inspect} to see format rules:\n") + print( + f"Please read the top block of {yaml_to_inspect} to see format rules:\n" + ) for err in validation_errors: print(" -", err) diff --git a/Agent0/executor_train/verl/tests/special_sanity/type_coverage_check.py b/Agent0/executor_train/verl/tests/special_sanity/type_coverage_check.py index dc6dc7c..f8a3fa3 100644 --- a/Agent0/executor_train/verl/tests/special_sanity/type_coverage_check.py +++ b/Agent0/executor_train/verl/tests/special_sanity/type_coverage_check.py @@ -26,7 +26,9 @@ def get_changed_files() -> list[Path]: result = subprocess.run( - ["git", "diff", "--name-only", "--diff-filter=AM", "origin/main...HEAD"], stdout=subprocess.PIPE, text=True + ["git", "diff", "--name-only", "--diff-filter=AM", "origin/main...HEAD"], + stdout=subprocess.PIPE, + text=True, ) return [Path(f) for f in result.stdout.splitlines() if f.endswith(".py")] @@ -70,14 +72,25 @@ def 
has_type_annotations(node: ast.AST, debug: bool = False) -> int: if isinstance(node, ast.FunctionDef): is_private = node.name.startswith("_") has_ann = ( - all(arg.annotation is not None for arg in node.args.args if should_check_type(arg.arg)) + all( + arg.annotation is not None + for arg in node.args.args + if should_check_type(arg.arg) + ) and node.returns is not None ) if has_ann or is_private: return CHECK_SUCCESS else: if debug: - print(node, [(arg.annotation, arg.arg) for arg in node.args.args if should_check_type(arg.arg)]) + print( + node, + [ + (arg.annotation, arg.arg) + for arg in node.args.args + if should_check_type(arg.arg) + ], + ) return CHECK_FAILURE return CHECK_SUCCESS @@ -102,7 +115,11 @@ def check_file( annotated += 1 if result == CHECK_WARNING: warning_lines.append( - (file_path, node.lineno, linecache.getline(str(file_path), node.lineno).strip()) + ( + file_path, + node.lineno, + linecache.getline(str(file_path), node.lineno).strip(), + ) ) else: source_line = linecache.getline(str(file_path), node.lineno).strip() @@ -114,9 +131,17 @@ def check_file( def main() -> None: parser = argparse.ArgumentParser() parser.add_argument( - "--threshold", type=float, default=0.3, help="Minimum ratio of annotated lines required (0.0 - 1.0)" + "--threshold", + type=float, + default=0.3, + help="Minimum ratio of annotated lines required (0.0 - 1.0)", + ) + parser.add_argument( + "--target-file", + type=str, + default=None, + help="Path to the Python source file to analyse", ) - parser.add_argument("--target-file", type=str, default=None, help="Path to the Python source file to analyse") parser.add_argument( "--all-lines", action="store_true", @@ -130,7 +155,9 @@ def main() -> None: all_warnings: list[tuple[Path, int, str]] = [] all_failures: list[tuple[Path, int, str]] = [] - target_files = [args.target_file] if args.target_file is not None else get_changed_files() + target_files = ( + [args.target_file] if args.target_file is not None else 
get_changed_files() + ) for fpath in target_files: if "tests/" in str(fpath): continue @@ -138,7 +165,9 @@ def main() -> None: changed_lines = [i + 1 for i in range(len(open(fpath).readlines()))] else: changed_lines = get_changed_lines(fpath) - annotated, total, warning_lines, failure_lines = check_file(fpath, changed_lines, args.debug) + annotated, total, warning_lines, failure_lines = check_file( + fpath, changed_lines, args.debug + ) total_annotated += annotated total_changed += total all_warnings.extend(warning_lines) @@ -152,7 +181,9 @@ def main() -> None: ) if all_warnings: - print("\nโš ๏ธ Suggest Improve: Lines missing type annotations for inputs and outputs:\n") + print( + "\nโš ๏ธ Suggest Improve: Lines missing type annotations for inputs and outputs:\n" + ) for fname, lineno, line in all_warnings: print(f"{fname}:{lineno}: {line}") diff --git a/Agent0/executor_train/verl/tests/special_sanity/validate_imported_docs.py b/Agent0/executor_train/verl/tests/special_sanity/validate_imported_docs.py index b36a407..c814cac 100644 --- a/Agent0/executor_train/verl/tests/special_sanity/validate_imported_docs.py +++ b/Agent0/executor_train/verl/tests/special_sanity/validate_imported_docs.py @@ -30,7 +30,9 @@ def _parse_args() -> argparse.Namespace: - p = argparse.ArgumentParser(description="Verify that imported functions/classes have docstrings.") + p = argparse.ArgumentParser( + description="Verify that imported functions/classes have docstrings." 
+ ) p.add_argument( "--target-file", default="verl/trainer/ppo/ray_trainer.py", @@ -60,7 +62,9 @@ def _import_attr(module_name: str, attr_name: str): return getattr(module, attr_name) -def _check_file(py_file: pathlib.Path, project_root: pathlib.Path, allow_list: list[str]) -> list[str]: +def _check_file( + py_file: pathlib.Path, project_root: pathlib.Path, allow_list: list[str] +) -> list[str]: """Return a list of error strings (empty == success).""" # Ensure local packages resolve sys.path.insert(0, str(project_root.resolve())) @@ -123,7 +127,9 @@ def main() -> None: raise Exception("โŒ Docstring verification failed.") if not args.quiet: - print(f"โœ… All explicitly imported functions/classes in {target_path} have docstrings.") + print( + f"โœ… All explicitly imported functions/classes in {target_path} have docstrings." + ) if __name__ == "__main__": diff --git a/Agent0/executor_train/verl/tests/special_sanity/validate_structure.py b/Agent0/executor_train/verl/tests/special_sanity/validate_structure.py index a5390b1..a61e0da 100644 --- a/Agent0/executor_train/verl/tests/special_sanity/validate_structure.py +++ b/Agent0/executor_train/verl/tests/special_sanity/validate_structure.py @@ -43,7 +43,9 @@ def discover_allowed_modules(impl_root: Path, extra: list[str]) -> set[str]: return allowed -def find_violations(tests_root: Path, allowed: set[str], allowed_files: list[str]) -> list[str]: +def find_violations( + tests_root: Path, allowed: set[str], allowed_files: list[str] +) -> list[str]: """Return a list of error strings for test files in the wrong place.""" errors: list[str] = [] for test_file in tests_root.rglob("test*.py"): @@ -51,7 +53,9 @@ def find_violations(tests_root: Path, allowed: set[str], allowed_files: list[str continue rel_parts = test_file.relative_to(tests_root).parts if len(rel_parts) < 2: - errors.append(f"{test_file}: must be inside one of {sorted(allowed)} (not at tests root)") + errors.append( + f"{test_file}: must be inside one of 
{sorted(allowed)} (not at tests root)" + ) continue first_folder = rel_parts[0] @@ -64,7 +68,9 @@ def find_violations(tests_root: Path, allowed: set[str], allowed_files: list[str def main() -> None: - parser = argparse.ArgumentParser(description="Check that test files follow tests//โ€ฆ layout.") + parser = argparse.ArgumentParser( + description="Check that test files follow tests//โ€ฆ layout." + ) parser.add_argument( "--impl-root", type=Path, @@ -80,7 +86,12 @@ def main() -> None: parser.add_argument( "--allow-dirs", nargs="*", - default=["special_e2e", "special_sanity", "special_standalone", "special_distributed"], + default=[ + "special_e2e", + "special_sanity", + "special_standalone", + "special_distributed", + ], help="Extra top-level test folders that are exempt from the rule", ) parser.add_argument( diff --git a/Agent0/executor_train/verl/tests/special_standalone/test_memory_buffers.py b/Agent0/executor_train/verl/tests/special_standalone/test_memory_buffers.py index 7785153..83de78d 100644 --- a/Agent0/executor_train/verl/tests/special_standalone/test_memory_buffers.py +++ b/Agent0/executor_train/verl/tests/special_standalone/test_memory_buffers.py @@ -43,7 +43,9 @@ def test_memory_buffers(): r_before = torch.cuda.memory_reserved(0) / norm_factor a_before = torch.cuda.memory_allocated(0) / norm_factor - print(f"Before Total memory: {t_before} GB, reserved: {r_before} GB, allocated: {a_before} GB") + print( + f"Before Total memory: {t_before} GB, reserved: {r_before} GB, allocated: {a_before} GB" + ) t = torch.cuda.get_device_properties(0).total_memory / norm_factor r = torch.cuda.memory_reserved(0) / norm_factor @@ -55,11 +57,17 @@ def test_memory_buffers(): print(f"After Total memory: {t} GB, reserved: {r} GB, allocated: {a} GB") change_ratio = (a - a_before) / a_before - assert change_ratio < 0.01, f"make sure the allocated change is less than 1%, Got {change_ratio}" + assert ( + change_ratio < 0.01 + ), f"make sure the allocated change is less than 1%, 
Got {change_ratio}" - for (name1, param1), (name2, param2) in zip(model.named_parameters(), model_copy.named_parameters(), strict=True): + for (name1, param1), (name2, param2) in zip( + model.named_parameters(), model_copy.named_parameters(), strict=True + ): assert name1 == name2 - assert torch.eq(param1.data, param2.data).all(), f"{param1.data}, {param2.data}, {name1}" + assert torch.eq( + param1.data, param2.data + ).all(), f"{param1.data}, {param2.data}, {name1}" if __name__ == "__main__": diff --git a/Agent0/executor_train/verl/tests/test_protocol_on_cpu.py b/Agent0/executor_train/verl/tests/test_protocol_on_cpu.py index 2052635..0bff12c 100644 --- a/Agent0/executor_train/verl/tests/test_protocol_on_cpu.py +++ b/Agent0/executor_train/verl/tests/test_protocol_on_cpu.py @@ -27,10 +27,14 @@ def test_union_tensor_dict(): obs = torch.randn(100, 10) data1 = TensorDict({"obs": obs, "act": torch.randn(100, 3)}, batch_size=[100]) - data2 = TensorDict({"obs": obs, "next_obs": torch.randn(100, 10), "rew": torch.randn(100)}, batch_size=[100]) + data2 = TensorDict( + {"obs": obs, "next_obs": torch.randn(100, 10), "rew": torch.randn(100)}, + batch_size=[100], + ) data_with_copied_obs = TensorDict( - {"obs": obs.clone(), "next_obs": torch.randn(100, 10), "rew": torch.randn(100)}, batch_size=[100] + {"obs": obs.clone(), "next_obs": torch.randn(100, 10), "rew": torch.randn(100)}, + batch_size=[100], ) data = union_tensor_dict(data1, data2) @@ -87,7 +91,9 @@ def test_tensor_dict_make_iterator(): print(data1.batch["obs"]) print(data2.batch["obs"]) raise AssertionError() - non_tensor_result = np.all(np.equal(data1.non_tensor_batch["labels"], data2.non_tensor_batch["labels"])) + non_tensor_result = np.all( + np.equal(data1.non_tensor_batch["labels"], data2.non_tensor_batch["labels"]) + ) if not non_tensor_result.item(): print(data1.non_tensor_batch["labels"]) print(data2.non_tensor_batch["labels"]) @@ -96,18 +102,28 @@ def test_tensor_dict_make_iterator(): def test_reorder(): obs 
= torch.tensor([1, 2, 3, 4, 5, 6]) labels = ["a", "b", "c", "d", "e", "f"] - data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"labels": labels}, meta_info={"name": "abdce"}) + data = DataProto.from_dict( + tensors={"obs": obs}, + non_tensors={"labels": labels}, + meta_info={"name": "abdce"}, + ) data.reorder(torch.tensor([3, 4, 2, 0, 1, 5])) assert torch.all(torch.eq(data.batch["obs"], torch.tensor([4, 5, 3, 1, 2, 6]))) - assert np.all(data.non_tensor_batch["labels"] == np.array(["d", "e", "c", "a", "b", "f"])) + assert np.all( + data.non_tensor_batch["labels"] == np.array(["d", "e", "c", "a", "b", "f"]) + ) assert data.meta_info == {"name": "abdce"} def test_chunk_concat(): obs = torch.tensor([1, 2, 3, 4, 5, 6]) labels = ["a", "b", "c", "d", "e", "f"] - data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"labels": labels}, meta_info={"name": "abdce"}) + data = DataProto.from_dict( + tensors={"obs": obs}, + non_tensors={"labels": labels}, + meta_info={"name": "abdce"}, + ) with pytest.raises(AssertionError): data.chunk(5) @@ -124,7 +140,9 @@ def test_chunk_concat(): concat_data = DataProto.concat(data_split) assert torch.all(torch.eq(concat_data.batch["obs"], data.batch["obs"])) - assert np.all(concat_data.non_tensor_batch["labels"] == data.non_tensor_batch["labels"]) + assert np.all( + concat_data.non_tensor_batch["labels"] == data.non_tensor_batch["labels"] + ) assert concat_data.meta_info == data.meta_info @@ -145,31 +163,53 @@ def test_repeat(): # Create a DataProto object with some batch and non-tensor data obs = torch.tensor([[1, 2], [3, 4], [5, 6]]) labels = ["a", "b", "c"] - data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"labels": labels}, meta_info={"info": "test_info"}) + data = DataProto.from_dict( + tensors={"obs": obs}, + non_tensors={"labels": labels}, + meta_info={"info": "test_info"}, + ) # Test interleave=True repeated_data_interleave = data.repeat(repeat_times=2, interleave=True) - expected_obs_interleave = 
torch.tensor([[1, 2], [1, 2], [3, 4], [3, 4], [5, 6], [5, 6]]) + expected_obs_interleave = torch.tensor( + [[1, 2], [1, 2], [3, 4], [3, 4], [5, 6], [5, 6]] + ) expected_labels_interleave = ["a", "a", "b", "b", "c", "c"] - assert torch.all(torch.eq(repeated_data_interleave.batch["obs"], expected_obs_interleave)) - assert (repeated_data_interleave.non_tensor_batch["labels"] == expected_labels_interleave).all() + assert torch.all( + torch.eq(repeated_data_interleave.batch["obs"], expected_obs_interleave) + ) + assert ( + repeated_data_interleave.non_tensor_batch["labels"] + == expected_labels_interleave + ).all() assert repeated_data_interleave.meta_info == {"info": "test_info"} # Test interleave=False repeated_data_no_interleave = data.repeat(repeat_times=2, interleave=False) - expected_obs_no_interleave = torch.tensor([[1, 2], [3, 4], [5, 6], [1, 2], [3, 4], [5, 6]]) + expected_obs_no_interleave = torch.tensor( + [[1, 2], [3, 4], [5, 6], [1, 2], [3, 4], [5, 6]] + ) expected_labels_no_interleave = ["a", "b", "c", "a", "b", "c"] - assert torch.all(torch.eq(repeated_data_no_interleave.batch["obs"], expected_obs_no_interleave)) - assert (repeated_data_no_interleave.non_tensor_batch["labels"] == expected_labels_no_interleave).all() + assert torch.all( + torch.eq(repeated_data_no_interleave.batch["obs"], expected_obs_no_interleave) + ) + assert ( + repeated_data_no_interleave.non_tensor_batch["labels"] + == expected_labels_no_interleave + ).all() assert repeated_data_no_interleave.meta_info == {"info": "test_info"} def test_dataproto_pad_unpad(): obs = torch.tensor([[1, 2], [3, 4], [5, 6]]) labels = ["a", "b", "c"] - data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"labels": labels}, meta_info={"info": "test_info"}) + data = DataProto.from_dict( + tensors={"obs": obs}, + non_tensors={"labels": labels}, + meta_info={"info": "test_info"}, + ) from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto @@ -206,7 +246,9 @@ def test_dataproto_pad_unpad(): 
padded_data, pad_size = pad_dataproto_to_divisor(data, size_divisor=7) assert pad_size == 4 - expected_obs = torch.tensor([[1, 2], [3, 4], [5, 6], [1, 2], [3, 4], [5, 6], [1, 2]]) + expected_obs = torch.tensor( + [[1, 2], [3, 4], [5, 6], [1, 2], [3, 4], [5, 6], [1, 2]] + ) expected_labels = ["a", "b", "c", "a", "b", "c", "a"] assert torch.all(torch.eq(padded_data.batch["obs"], expected_obs)) assert (padded_data.non_tensor_batch["labels"] == expected_labels).all() @@ -223,20 +265,32 @@ def test_dataproto_fold_unfold(): obs = torch.tensor([[1, 2], [3, 4], [5, 6]]) labels = ["a", "b", "c"] - data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"labels": labels}, meta_info={"info": "test_info"}) + data = DataProto.from_dict( + tensors={"obs": obs}, + non_tensors={"labels": labels}, + meta_info={"info": "test_info"}, + ) data1 = data.repeat(repeat_times=2, interleave=True) data2 = fold_batch_dim(data1, new_batch_size=3) - torch.testing.assert_close(data2.batch["obs"], torch.tensor([[[1, 2], [1, 2]], [[3, 4], [3, 4]], [[5, 6], [5, 6]]])) - assert (data2.non_tensor_batch["labels"] == [["a", "a"], ["b", "b"], ["c", "c"]]).all() + torch.testing.assert_close( + data2.batch["obs"], + torch.tensor([[[1, 2], [1, 2]], [[3, 4], [3, 4]], [[5, 6], [5, 6]]]), + ) + assert ( + data2.non_tensor_batch["labels"] == [["a", "a"], ["b", "b"], ["c", "c"]] + ).all() data2.reorder(indices=torch.tensor([1, 2, 0])) data3 = unfold_batch_dim(data2, batch_dims=2) - torch.testing.assert_close(data3.batch["obs"], torch.tensor([[3, 4], [3, 4], [5, 6], [5, 6], [1, 2], [1, 2]])) + torch.testing.assert_close( + data3.batch["obs"], + torch.tensor([[3, 4], [3, 4], [5, 6], [5, 6], [1, 2], [1, 2]]), + ) assert (data3.non_tensor_batch["labels"] == ["b", "b", "c", "c", "a", "a"]).all() assert data3.meta_info == {"info": "test_info"} @@ -244,12 +298,18 @@ def test_dataproto_fold_unfold(): def test_torch_save_data_proto(): obs = torch.tensor([[1, 2], [3, 4], [5, 6]]) labels = ["a", "b", "c"] - data = 
DataProto.from_dict(tensors={"obs": obs}, non_tensors={"labels": labels}, meta_info={"info": "test_info"}) + data = DataProto.from_dict( + tensors={"obs": obs}, + non_tensors={"labels": labels}, + meta_info={"info": "test_info"}, + ) data.save_to_disk("test_data.pt") loaded_data = DataProto.load_from_disk("test_data.pt") assert torch.all(torch.eq(loaded_data.batch["obs"], data.batch["obs"])) - assert (loaded_data.non_tensor_batch["labels"] == data.non_tensor_batch["labels"]).all() + assert ( + loaded_data.non_tensor_batch["labels"] == data.non_tensor_batch["labels"] + ).all() assert loaded_data.meta_info == data.meta_info import os @@ -260,11 +320,17 @@ def test_torch_save_data_proto(): def test_len(): obs = torch.tensor([[1, 2], [3, 4], [5, 6]]) labels = np.array(["a", "b", "c"], dtype=object) - data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"labels": labels}, meta_info={"info": "test_info"}) + data = DataProto.from_dict( + tensors={"obs": obs}, + non_tensors={"labels": labels}, + meta_info={"info": "test_info"}, + ) assert len(data) == 3 - data = DataProto(batch=None, non_tensor_batch={"labels": labels}, meta_info={"info": "test_info"}) + data = DataProto( + batch=None, non_tensor_batch={"labels": labels}, meta_info={"info": "test_info"} + ) assert len(data) == 3 @@ -292,8 +358,12 @@ def test_dataproto_index(): assert result_np_int.non_tensor_batch.keys() == data.non_tensor_batch.keys() assert result_np_int.batch["obs"].shape[0] == idx_num assert result_np_int.non_tensor_batch["labels"].shape[0] == idx_num - assert np.array_equal(result_np_int.batch["obs"].cpu().numpy(), obs[idx_np_int].numpy()) - assert np.array_equal(result_np_int.non_tensor_batch["labels"], labels_np[idx_np_int]) + assert np.array_equal( + result_np_int.batch["obs"].cpu().numpy(), obs[idx_np_int].numpy() + ) + assert np.array_equal( + result_np_int.non_tensor_batch["labels"], labels_np[idx_np_int] + ) idx_torch_int = torch.randint(0, data_len, size=(idx_num,)) result_torch_int = 
data[idx_torch_int] @@ -301,8 +371,13 @@ def test_dataproto_index(): assert result_torch_int.non_tensor_batch.keys() == data.non_tensor_batch.keys() assert result_torch_int.batch["obs"].shape[0] == idx_num assert result_torch_int.non_tensor_batch["labels"].shape[0] == idx_num - assert np.array_equal(result_torch_int.batch["obs"].cpu().numpy(), obs[idx_torch_int].cpu().numpy()) - assert np.array_equal(result_torch_int.non_tensor_batch["labels"], labels_np[idx_torch_int.cpu().numpy()]) + assert np.array_equal( + result_torch_int.batch["obs"].cpu().numpy(), obs[idx_torch_int].cpu().numpy() + ) + assert np.array_equal( + result_torch_int.non_tensor_batch["labels"], + labels_np[idx_torch_int.cpu().numpy()], + ) idx_list_int = [np.random.randint(0, data_len) for _ in range(idx_num)] result_list_int = data[idx_list_int] @@ -310,8 +385,12 @@ def test_dataproto_index(): assert result_list_int.non_tensor_batch.keys() == data.non_tensor_batch.keys() assert result_list_int.batch["obs"].shape[0] == idx_num assert result_list_int.non_tensor_batch["labels"].shape[0] == idx_num - assert np.array_equal(result_list_int.batch["obs"].cpu().numpy(), obs[idx_list_int].cpu().numpy()) - assert np.array_equal(result_list_int.non_tensor_batch["labels"], labels_np[idx_list_int]) + assert np.array_equal( + result_list_int.batch["obs"].cpu().numpy(), obs[idx_list_int].cpu().numpy() + ) + assert np.array_equal( + result_list_int.non_tensor_batch["labels"], labels_np[idx_list_int] + ) idx_np_bool = np.random.randint(0, 2, size=(data_len,), dtype=bool) result_np_bool = data[idx_np_bool] @@ -319,17 +398,28 @@ def test_dataproto_index(): assert result_np_bool.non_tensor_batch.keys() == data.non_tensor_batch.keys() assert result_np_bool.batch["obs"].shape[0] == idx_np_bool.sum() assert result_np_bool.non_tensor_batch["labels"].shape[0] == idx_np_bool.sum() - assert np.array_equal(result_np_bool.batch["obs"].cpu().numpy(), obs[idx_np_bool].cpu().numpy()) - assert 
np.array_equal(result_np_bool.non_tensor_batch["labels"], labels_np[idx_np_bool]) + assert np.array_equal( + result_np_bool.batch["obs"].cpu().numpy(), obs[idx_np_bool].cpu().numpy() + ) + assert np.array_equal( + result_np_bool.non_tensor_batch["labels"], labels_np[idx_np_bool] + ) idx_torch_bool = torch.randint(0, 2, size=(data_len,), dtype=torch.bool) result_torch_bool = data[idx_torch_bool] assert result_torch_bool.batch.keys() == data.batch.keys() assert result_torch_bool.non_tensor_batch.keys() == data.non_tensor_batch.keys() assert result_torch_bool.batch["obs"].shape[0] == idx_torch_bool.sum().item() - assert result_torch_bool.non_tensor_batch["labels"].shape[0] == idx_torch_bool.sum().item() - assert np.array_equal(result_torch_bool.batch["obs"].cpu().numpy(), obs[idx_torch_bool].cpu().numpy()) - assert np.array_equal(result_torch_bool.non_tensor_batch["labels"], labels_np[idx_torch_bool]) + assert ( + result_torch_bool.non_tensor_batch["labels"].shape[0] + == idx_torch_bool.sum().item() + ) + assert np.array_equal( + result_torch_bool.batch["obs"].cpu().numpy(), obs[idx_torch_bool].cpu().numpy() + ) + assert np.array_equal( + result_torch_bool.non_tensor_batch["labels"], labels_np[idx_torch_bool] + ) idx_list_bool = [np.random.randint(0, 2, dtype=bool) for _ in range(data_len)] result_list_bool = data[idx_list_bool] @@ -337,8 +427,12 @@ def test_dataproto_index(): assert result_list_bool.non_tensor_batch.keys() == data.non_tensor_batch.keys() assert result_list_bool.batch["obs"].shape[0] == sum(idx_list_bool) assert result_list_bool.non_tensor_batch["labels"].shape[0] == sum(idx_list_bool) - assert np.array_equal(result_list_bool.batch["obs"].cpu().numpy(), obs[idx_list_bool].cpu().numpy()) - assert np.array_equal(result_list_bool.non_tensor_batch["labels"], labels_np[idx_list_bool]) + assert np.array_equal( + result_list_bool.batch["obs"].cpu().numpy(), obs[idx_list_bool].cpu().numpy() + ) + assert np.array_equal( + 
result_list_bool.non_tensor_batch["labels"], labels_np[idx_list_bool] + ) def test_old_vs_new_from_single_dict(): @@ -380,7 +474,9 @@ def from_single_dict(cls, data, meta_info=None, auto_padding=False): def test_dataproto_no_batch(): labels = ["a", "b", "c"] - data = DataProto.from_dict(non_tensors={"labels": labels}, meta_info={"info": "test_info"}) + data = DataProto.from_dict( + non_tensors={"labels": labels}, meta_info={"info": "test_info"} + ) selected = data.select(non_tensor_batch_keys=["labels"]) assert (selected.non_tensor_batch["labels"] == labels).all() pop_data = data.pop(non_tensor_batch_keys=["labels"]) @@ -392,24 +488,44 @@ def test_sample_level_repeat(): # Create a DataProto object with some batch and non-tensor data obs = torch.tensor([[1, 2], [3, 4], [5, 6]]) labels = ["a", "b", "c"] - data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"labels": labels}, meta_info={"info": "test_info"}) + data = DataProto.from_dict( + tensors={"obs": obs}, + non_tensors={"labels": labels}, + meta_info={"info": "test_info"}, + ) # list repeated_data_interleave = data.sample_level_repeat(repeat_times=[3, 1, 2]) - expected_obs_interleave = torch.tensor([[1, 2], [1, 2], [1, 2], [3, 4], [5, 6], [5, 6]]) + expected_obs_interleave = torch.tensor( + [[1, 2], [1, 2], [1, 2], [3, 4], [5, 6], [5, 6]] + ) expected_labels_interleave = ["a", "a", "a", "b", "c", "c"] - assert torch.all(torch.eq(repeated_data_interleave.batch["obs"], expected_obs_interleave)) - assert (repeated_data_interleave.non_tensor_batch["labels"] == expected_labels_interleave).all() + assert torch.all( + torch.eq(repeated_data_interleave.batch["obs"], expected_obs_interleave) + ) + assert ( + repeated_data_interleave.non_tensor_batch["labels"] + == expected_labels_interleave + ).all() assert repeated_data_interleave.meta_info == {"info": "test_info"} # torch.tensor - repeated_data_no_interleave = data.sample_level_repeat(repeat_times=torch.tensor([1, 2, 3])) - expected_obs_no_interleave = 
torch.tensor([[1, 2], [3, 4], [3, 4], [5, 6], [5, 6], [5, 6]]) + repeated_data_no_interleave = data.sample_level_repeat( + repeat_times=torch.tensor([1, 2, 3]) + ) + expected_obs_no_interleave = torch.tensor( + [[1, 2], [3, 4], [3, 4], [5, 6], [5, 6], [5, 6]] + ) expected_labels_no_interleave = ["a", "b", "b", "c", "c", "c"] - assert torch.all(torch.eq(repeated_data_no_interleave.batch["obs"], expected_obs_no_interleave)) - assert (repeated_data_no_interleave.non_tensor_batch["labels"] == expected_labels_no_interleave).all() + assert torch.all( + torch.eq(repeated_data_no_interleave.batch["obs"], expected_obs_no_interleave) + ) + assert ( + repeated_data_no_interleave.non_tensor_batch["labels"] + == expected_labels_no_interleave + ).all() assert repeated_data_no_interleave.meta_info == {"info": "test_info"} @@ -419,7 +535,9 @@ def test_dataproto_unfold_column_chunks(): labels = ["a", "b", "c"] data = DataProto.from_dict( - tensors={"obs1": obs1, "obs2": obs2}, non_tensors={"labels": labels}, meta_info={"name": "abc"} + tensors={"obs1": obs1, "obs2": obs2}, + non_tensors={"labels": labels}, + meta_info={"name": "abc"}, ) ret = data.unfold_column_chunks(2, split_keys=["obs1"]) @@ -436,7 +554,9 @@ def test_dataproto_unfold_column_chunks(): labels = [["a1", "a2"], ["b1", "b2"], ["c1", "c2"]] data = DataProto.from_dict( - tensors={"obs1": obs1, "obs2": obs2}, non_tensors={"labels": labels}, meta_info={"name": "abc"} + tensors={"obs1": obs1, "obs2": obs2}, + non_tensors={"labels": labels}, + meta_info={"name": "abc"}, ) ret = data.unfold_column_chunks(2, split_keys=["obs1", "labels"]) @@ -449,13 +569,19 @@ def test_dataproto_unfold_column_chunks(): assert ret.meta_info == {"name": "abc"} obs1 = torch.tensor( - [[[1, 1], [2, 2], [3, 3], [4, 4]], [[5, 5], [6, 6], [7, 7], [8, 8]], [[9, 9], [10, 10], [11, 11], [12, 12]]] + [ + [[1, 1], [2, 2], [3, 3], [4, 4]], + [[5, 5], [6, 6], [7, 7], [8, 8]], + [[9, 9], [10, 10], [11, 11], [12, 12]], + ] ) obs2 = torch.tensor([[[1, 1], 
[2, 2]], [[5, 5], [6, 6]], [[9, 9], [10, 10]]]) labels = ["a", "b", "c"] data = DataProto.from_dict( - tensors={"obs1": obs1, "obs2": obs2}, non_tensors={"labels": labels}, meta_info={"name": "abc"} + tensors={"obs1": obs1, "obs2": obs2}, + non_tensors={"labels": labels}, + meta_info={"name": "abc"}, ) ret = data.unfold_column_chunks(2, split_keys=["obs1"]) @@ -470,7 +596,14 @@ def test_dataproto_unfold_column_chunks(): ] ) expect_obs2 = torch.tensor( - [[[1, 1], [2, 2]], [[1, 1], [2, 2]], [[5, 5], [6, 6]], [[5, 5], [6, 6]], [[9, 9], [10, 10]], [[9, 9], [10, 10]]] + [ + [[1, 1], [2, 2]], + [[1, 1], [2, 2]], + [[5, 5], [6, 6]], + [[5, 5], [6, 6]], + [[9, 9], [10, 10]], + [[9, 9], [10, 10]], + ] ) expect_labels = ["a", "a", "b", "b", "c", "c"] assert torch.all(torch.eq(ret.batch["obs1"], expect_obs1)) @@ -483,13 +616,17 @@ def test_dataproto_chunk_after_index(): data_len = 4 obs = torch.randn(data_len, 4) labels = [f"label_{i}" for i in range(data_len)] - data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"labels": labels}, meta_info={"name": "abc"}) + data = DataProto.from_dict( + tensors={"obs": obs}, non_tensors={"labels": labels}, meta_info={"name": "abc"} + ) # Test with boolean numpy array bool_mask = np.array([True, False, True, False]) selected = data[bool_mask] assert isinstance(selected.batch.batch_size, torch.Size) - assert all(isinstance(d, int) for d in selected.batch.batch_size) # int or List[int] + assert all( + isinstance(d, int) for d in selected.batch.batch_size + ) # int or List[int] # Test with integer numpy array int_mask = np.array([0, 2]) diff --git a/Agent0/executor_train/verl/tests/tools/test_base_tool_on_cpu.py b/Agent0/executor_train/verl/tests/tools/test_base_tool_on_cpu.py index 63a2bbb..abf4977 100644 --- a/Agent0/executor_train/verl/tests/tools/test_base_tool_on_cpu.py +++ b/Agent0/executor_train/verl/tests/tools/test_base_tool_on_cpu.py @@ -44,7 +44,9 @@ def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema: schema = 
get_json_schema(self.get_current_temperature) return OpenAIFunctionToolSchema(**schema) - async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]: + async def execute( + self, instance_id: str, parameters: dict[str, Any], **kwargs + ) -> tuple[str, float, dict]: try: result = self.get_current_temperature(**parameters) return json.dumps(result), 0, {} @@ -75,7 +77,9 @@ def get_temperature_date(self, location: str, date: str, unit: str = "celsius"): "unit": unit, } - async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]: + async def execute( + self, instance_id: str, parameters: dict[str, Any], **kwargs + ) -> tuple[str, float, dict]: try: result = self.get_temperature_date(**parameters) return json.dumps(result), 0, {} @@ -152,7 +156,10 @@ def test_initialize_tools_from_local_config(create_local_tool_config): tools = initialize_tools_from_config(tool_config_path) assert len(tools) == 2 - from tests.tools.test_base_tool_on_cpu import WeatherToolForTest, WeatherToolWithDataForTest + from tests.tools.test_base_tool_on_cpu import ( + WeatherToolForTest, + WeatherToolWithDataForTest, + ) assert isinstance(tools[0], WeatherToolForTest) assert isinstance(tools[1], WeatherToolWithDataForTest) diff --git a/Agent0/executor_train/verl/tests/trainer/config/test_algo_config_on_cpu.py b/Agent0/executor_train/verl/tests/trainer/config/test_algo_config_on_cpu.py index 848a3ff..afeee14 100644 --- a/Agent0/executor_train/verl/tests/trainer/config/test_algo_config_on_cpu.py +++ b/Agent0/executor_train/verl/tests/trainer/config/test_algo_config_on_cpu.py @@ -49,7 +49,11 @@ def setUp(self): "target_kl": 0.05, }, "use_pf_ppo": True, - "pf_ppo": {"_target_": "verl.trainer.config.PFPPOConfig", "reweight_method": "max_min", "weight_pow": 3.0}, + "pf_ppo": { + "_target_": "verl.trainer.config.PFPPOConfig", + "reweight_method": "max_min", + "weight_pow": 3.0, + }, } self.omega_config = 
OmegaConf.create(self.config_dict) @@ -151,7 +155,9 @@ def setUp(self): norm_adv_by_std_in_grpo=True, use_kl_in_reward=True, kl_penalty="kl", - kl_ctrl=KLControlConfig(type="adaptive", kl_coef=0.002, horizon=5000, target_kl=0.05), + kl_ctrl=KLControlConfig( + type="adaptive", kl_coef=0.002, horizon=5000, target_kl=0.05 + ), use_pf_ppo=True, pf_ppo=PFPPOConfig(reweight_method="max_min", weight_pow=3.0), ) @@ -187,7 +193,9 @@ def test_grpo_advantage_estimator_with_cfg(self): # Test GRPO advantage computation batch_size, seq_len = 4, 3 - token_level_rewards = torch.tensor([[1.0, 0.5, 0.0], [2.0, 1.0, 0.0], [0.5, 0.2, 0.0], [1.5, 0.8, 0.0]]) + token_level_rewards = torch.tensor( + [[1.0, 0.5, 0.0], [2.0, 1.0, 0.0], [0.5, 0.2, 0.0], [1.5, 0.8, 0.0]] + ) response_mask = torch.ones(batch_size, seq_len) index = np.array([0, 0, 1, 1]) # Two groups diff --git a/Agent0/executor_train/verl/tests/trainer/config/test_legacy_config_on_cpu.py b/Agent0/executor_train/verl/tests/trainer/config/test_legacy_config_on_cpu.py index 39862aa..e79b6ae 100644 --- a/Agent0/executor_train/verl/tests/trainer/config/test_legacy_config_on_cpu.py +++ b/Agent0/executor_train/verl/tests/trainer/config/test_legacy_config_on_cpu.py @@ -23,7 +23,9 @@ class TestConfigComparison(unittest.TestCase): """Test that current configs match their legacy counterparts exactly.""" - def _compare_configs_recursively(self, current_config, legacy_config, path="", legacy_allow_missing=True): + def _compare_configs_recursively( + self, current_config, legacy_config, path="", legacy_allow_missing=True + ): """Recursively compare two OmegaConf configs and assert they are identical. 
Args: @@ -38,7 +40,9 @@ def _compare_configs_recursively(self, current_config, legacy_config, path="", l missing_in_legacy = current_keys - legacy_keys if missing_in_current: - self.fail(f"Keys missing in current config at {path}: {missing_in_current}") + self.fail( + f"Keys missing in current config at {path}: {missing_in_current}" + ) if missing_in_legacy: # if the legacy msg = f"Keys missing in legacy config at {path}: {missing_in_legacy}" @@ -50,15 +54,21 @@ def _compare_configs_recursively(self, current_config, legacy_config, path="", l for key in current_keys: current_path = f"{path}.{key}" if path else key if key in legacy_config: - self._compare_configs_recursively(current_config[key], legacy_config[key], current_path) + self._compare_configs_recursively( + current_config[key], legacy_config[key], current_path + ) elif isinstance(current_config, list) and isinstance(legacy_config, list): self.assertEqual( len(current_config), len(legacy_config), f"List lengths differ at {path}: current={len(current_config)}, legacy={len(legacy_config)}", ) - for i, (current_item, legacy_item) in enumerate(zip(current_config, legacy_config, strict=True)): - self._compare_configs_recursively(current_item, legacy_item, f"{path}[{i}]") + for i, (current_item, legacy_item) in enumerate( + zip(current_config, legacy_config, strict=True) + ): + self._compare_configs_recursively( + current_item, legacy_item, f"{path}[{i}]" + ) else: self.assertEqual( current_config, @@ -76,10 +86,14 @@ def test_ppo_trainer_config_matches_legacy(self): GlobalHydra.instance().clear() try: - with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")): + with initialize_config_dir( + config_dir=os.path.abspath("verl/trainer/config") + ): current_config = compose(config_name="ppo_trainer") - legacy_config = OmegaConf.load("tests/trainer/config/legacy_ppo_trainer.yaml") + legacy_config = OmegaConf.load( + "tests/trainer/config/legacy_ppo_trainer.yaml" + ) current_dict = 
OmegaConf.to_container(current_config, resolve=True) legacy_dict = OmegaConf.to_container(legacy_config, resolve=True) @@ -96,17 +110,23 @@ def test_ppo_megatron_trainer_config_matches_legacy(self): GlobalHydra.instance().clear() try: - with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")): + with initialize_config_dir( + config_dir=os.path.abspath("verl/trainer/config") + ): current_config = compose(config_name="ppo_megatron_trainer") - legacy_config = OmegaConf.load("tests/trainer/config/legacy_ppo_megatron_trainer.yaml") + legacy_config = OmegaConf.load( + "tests/trainer/config/legacy_ppo_megatron_trainer.yaml" + ) current_dict = OmegaConf.to_container(current_config, resolve=True) legacy_dict = OmegaConf.to_container(legacy_config, resolve=True) if "defaults" in current_dict: del current_dict["defaults"] - self._compare_configs_recursively(current_dict, legacy_dict, legacy_allow_missing=True) + self._compare_configs_recursively( + current_dict, legacy_dict, legacy_allow_missing=True + ) finally: GlobalHydra.instance().clear() diff --git a/Agent0/executor_train/verl/tests/trainer/ppo/test_core_algos_on_cpu.py b/Agent0/executor_train/verl/tests/trainer/ppo/test_core_algos_on_cpu.py index 087a0d2..8efd91b 100644 --- a/Agent0/executor_train/verl/tests/trainer/ppo/test_core_algos_on_cpu.py +++ b/Agent0/executor_train/verl/tests/trainer/ppo/test_core_algos_on_cpu.py @@ -19,7 +19,11 @@ import torch import verl.trainer.ppo.core_algos -from verl.trainer.ppo.core_algos import compute_gae_advantage_return, get_adv_estimator_fn, register_adv_est +from verl.trainer.ppo.core_algos import ( + compute_gae_advantage_return, + get_adv_estimator_fn, + register_adv_est, +) def mock_test_fn(): @@ -136,7 +140,9 @@ def test_multi_turn_compute_gae_advantage_return(): gamma = random.uniform(0.0, 1.0) lam = random.uniform(0.0, 1.0) - rewards = torch.tensor([[0.0, 0.0, 0.1, 0.1, 0.1, 0.0, 0.0, 0.1, 1.0, 0.0, 0.0]], dtype=torch.float) + rewards = torch.tensor( + 
[[0.0, 0.0, 0.1, 0.1, 0.1, 0.0, 0.0, 0.1, 1.0, 0.0, 0.0]], dtype=torch.float + ) values1 = torch.tensor( [ @@ -178,8 +184,12 @@ def test_multi_turn_compute_gae_advantage_return(): response_mask = torch.tensor([[0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0]], dtype=torch.float) - adv1, ret1 = compute_gae_advantage_return(rewards, values1, response_mask, gamma, lam) - adv2, ret2 = compute_gae_advantage_return(rewards, values2, response_mask, gamma, lam) + adv1, ret1 = compute_gae_advantage_return( + rewards, values1, response_mask, gamma, lam + ) + adv2, ret2 = compute_gae_advantage_return( + rewards, values2, response_mask, gamma, lam + ) ret1 *= response_mask ret2 *= response_mask diff --git a/Agent0/executor_train/verl/tests/trainer/ppo/test_metric_utils_on_cpu.py b/Agent0/executor_train/verl/tests/trainer/ppo/test_metric_utils_on_cpu.py index 50fe952..3b4e67c 100644 --- a/Agent0/executor_train/verl/tests/trainer/ppo/test_metric_utils_on_cpu.py +++ b/Agent0/executor_train/verl/tests/trainer/ppo/test_metric_utils_on_cpu.py @@ -110,8 +110,12 @@ def test_compute_data_metrics_with_critic(self): self.assertIn("prompt_length/mean", metrics) # Check some specific values - self.assertAlmostEqual(metrics["critic/score/mean"], 5.0) # Sum of token_level_scores - self.assertAlmostEqual(metrics["critic/rewards/mean"], 2.5) # Sum of token_level_rewards + self.assertAlmostEqual( + metrics["critic/score/mean"], 5.0 + ) # Sum of token_level_scores + self.assertAlmostEqual( + metrics["critic/rewards/mean"], 2.5 + ) # Sum of token_level_rewards def test_compute_data_metrics_without_critic(self): """Test compute_data_metrics with critic disabled.""" @@ -171,11 +175,17 @@ def test_compute_timing_metrics(self, mock_compute_response_info): # Check per-token timing metrics # gen uses only response tokens (6 tokens) - self.assertAlmostEqual(metrics["timing_per_token_ms/gen"], 0.5 * 1000 / 6, places=5) + self.assertAlmostEqual( + metrics["timing_per_token_ms/gen"], 0.5 * 1000 / 6, places=5 + ) # ref 
and values use all tokens (12 tokens) - self.assertAlmostEqual(metrics["timing_per_token_ms/ref"], 0.3 * 1000 / 12, places=5) - self.assertAlmostEqual(metrics["timing_per_token_ms/values"], 0.2 * 1000 / 12, places=5) + self.assertAlmostEqual( + metrics["timing_per_token_ms/ref"], 0.3 * 1000 / 12, places=5 + ) + self.assertAlmostEqual( + metrics["timing_per_token_ms/values"], 0.2 * 1000 / 12, places=5 + ) class TestComputeThroughputMetrics(unittest.TestCase): @@ -207,7 +217,9 @@ def test_compute_throughout_metrics(self): self.assertEqual(metrics["perf/total_num_tokens"], 600) self.assertEqual(metrics["perf/time_per_step"], 2.0) - self.assertEqual(metrics["perf/throughput"], 600 / (2.0 * 2)) # 150 tokens/sec/GPU + self.assertEqual( + metrics["perf/throughput"], 600 / (2.0 * 2) + ) # 150 tokens/sec/GPU class TestBootstrapMetric(unittest.TestCase): @@ -219,7 +231,9 @@ def test_bootstrap_metric_basic(self): reduce_fns = [np.mean, np.max] # Use a fixed seed for reproducibility - result = bootstrap_metric(data, subset_size=3, reduce_fns=reduce_fns, n_bootstrap=100, seed=42) + result = bootstrap_metric( + data, subset_size=3, reduce_fns=reduce_fns, n_bootstrap=100, seed=42 + ) # Check that we get two results (one for each reduce_fn) self.assertEqual(len(result), 2) @@ -287,7 +301,9 @@ def test_process_validation_metrics_basic(self): "score": [0.8, 0.9, 0.7], } - result = process_validation_metrics(data_sources, sample_inputs, infos_dict, seed=42) + result = process_validation_metrics( + data_sources, sample_inputs, infos_dict, seed=42 + ) # Check the structure of the result self.assertIn("source1", result) @@ -311,7 +327,9 @@ def test_process_validation_metrics_with_pred(self): "pred": ["A", "B", "A"], } - result = process_validation_metrics(data_sources, sample_inputs, infos_dict, seed=42) + result = process_validation_metrics( + data_sources, sample_inputs, infos_dict, seed=42 + ) # Check that majority voting metrics are present self.assertIn("maj@2/mean", 
result["source1"]["score"]) diff --git a/Agent0/executor_train/verl/tests/utils/ckpt/test_esi_save_ckpt_on_cpu.py b/Agent0/executor_train/verl/tests/utils/ckpt/test_esi_save_ckpt_on_cpu.py index 203494b..5ab7955 100644 --- a/Agent0/executor_train/verl/tests/utils/ckpt/test_esi_save_ckpt_on_cpu.py +++ b/Agent0/executor_train/verl/tests/utils/ckpt/test_esi_save_ckpt_on_cpu.py @@ -29,13 +29,17 @@ def test_no_expiration_timestamp(self): def test_mlp_expiration_valid(self): """Test valid MLP expiration timestamp requiring save""" current_time = time.time() - os.environ["MLP_CURRENT_CAPACITY_BLOCK_EXPIRATION_TIMESTAMP"] = str(current_time + 90) + os.environ["MLP_CURRENT_CAPACITY_BLOCK_EXPIRATION_TIMESTAMP"] = str( + current_time + 90 + ) self.assertTrue(should_save_ckpt_esi(30)) # max_steps_duration=30 seconds def test_mlp_expiration_passed(self): """Test expired MLP timestamp""" current_time = time.time() - os.environ["MLP_CURRENT_CAPACITY_BLOCK_EXPIRATION_TIMESTAMP"] = str(current_time - 10) + os.environ["MLP_CURRENT_CAPACITY_BLOCK_EXPIRATION_TIMESTAMP"] = str( + current_time - 10 + ) self.assertFalse(should_save_ckpt_esi(30)) def test_mlp_invalid_timestamp(self): @@ -46,25 +50,33 @@ def test_mlp_invalid_timestamp(self): def test_mlp_expiration_not_reached(self): """Test MLP expiration timestamp with insufficient remaining time""" current_time = time.time() - os.environ["MLP_CURRENT_CAPACITY_BLOCK_EXPIRATION_TIMESTAMP"] = str(current_time + 200) + os.environ["MLP_CURRENT_CAPACITY_BLOCK_EXPIRATION_TIMESTAMP"] = str( + current_time + 200 + ) self.assertFalse(should_save_ckpt_esi(30)) # max_steps_duration=30 def test_aws_expiration_not_reached(self): """Test AWS expiration timestamp with sufficient remaining time""" now = datetime.now() expiration = now + timedelta(minutes=100) # Exceeds 90-minute threshold - os.environ["SAGEMAKER_CURRENT_CAPACITY_BLOCK_EXPIRATION_TIMESTAMP"] = str(int(expiration.timestamp())) + 
os.environ["SAGEMAKER_CURRENT_CAPACITY_BLOCK_EXPIRATION_TIMESTAMP"] = str( + int(expiration.timestamp()) + ) self.assertFalse(should_save_ckpt_esi(30 * 60)) def test_redundant_time(self): """Test redundant_time parameter effect""" current_time = time.time() # Total required: 60+30+30=120 seconds - os.environ["MLP_CURRENT_CAPACITY_BLOCK_EXPIRATION_TIMESTAMP"] = str(current_time + 120) + os.environ["MLP_CURRENT_CAPACITY_BLOCK_EXPIRATION_TIMESTAMP"] = str( + current_time + 120 + ) self.assertTrue(should_save_ckpt_esi(30, redundant_time=30)) def test_zero_max_steps_duration(self): """Test zero max_steps_duration""" current_time = time.time() - os.environ["MLP_CURRENT_CAPACITY_BLOCK_EXPIRATION_TIMESTAMP"] = str(current_time + 60) + os.environ["MLP_CURRENT_CAPACITY_BLOCK_EXPIRATION_TIMESTAMP"] = str( + current_time + 60 + ) self.assertFalse(should_save_ckpt_esi(0)) diff --git a/Agent0/executor_train/verl/tests/utils/dataset/test_multiturn_sft_dataset_on_cpu.py b/Agent0/executor_train/verl/tests/utils/dataset/test_multiturn_sft_dataset_on_cpu.py index 8028d44..5f1c5c5 100644 --- a/Agent0/executor_train/verl/tests/utils/dataset/test_multiturn_sft_dataset_on_cpu.py +++ b/Agent0/executor_train/verl/tests/utils/dataset/test_multiturn_sft_dataset_on_cpu.py @@ -56,8 +56,14 @@ def test_multiturn_sft_dataset(): # Initialize tokenizer and dataset tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Coder-7B-Instruct") - config = {"max_length": 512, "truncation": "error", "multiturn": {"messages_key": "messages"}} - dataset = MultiTurnSFTDataset(parquet_files=test_file, tokenizer=tokenizer, config=config) + config = { + "max_length": 512, + "truncation": "error", + "multiturn": {"messages_key": "messages"}, + } + dataset = MultiTurnSFTDataset( + parquet_files=test_file, tokenizer=tokenizer, config=config + ) # Test 1: Dataset Length assert len(dataset) == 2, f"Expected dataset length 2, got {len(dataset)}" @@ -71,14 +77,20 @@ def test_multiturn_sft_dataset(): for key in 
required_keys: assert key in item0, f"Missing key {key} in dataset item" assert isinstance(item0[key], torch.Tensor), f"Expected torch.Tensor for {key}" - assert item0[key].dtype == torch.long, f"Expected torch.long for {key}, got {item0[key].dtype}" + assert ( + item0[key].dtype == torch.long + ), f"Expected torch.long for {key}, got {item0[key].dtype}" # Test 3: Shape Consistency - assert item0["loss_mask"].shape == item0["input_ids"].shape, "Loss mask shape doesn't match input_ids shape" - assert item0["attention_mask"].shape == item0["input_ids"].shape, ( - "Attention mask shape doesn't match input_ids shape" - ) - assert item0["position_ids"].shape == item0["input_ids"].shape, "Position IDs shape doesn't match input_ids shape" + assert ( + item0["loss_mask"].shape == item0["input_ids"].shape + ), "Loss mask shape doesn't match input_ids shape" + assert ( + item0["attention_mask"].shape == item0["input_ids"].shape + ), "Attention mask shape doesn't match input_ids shape" + assert ( + item0["position_ids"].shape == item0["input_ids"].shape + ), "Position IDs shape doesn't match input_ids shape" # Test 4: Loss Mask Pattern - Math Conversation loss_mask0 = item0["loss_mask"] @@ -105,24 +117,32 @@ def test_multiturn_sft_dataset(): # Decode and verify assistant responses assistant_text1 = tokenizer.decode(input_ids1[loss_mask1 == 1]) print(f"Joke conversation assistant text: {assistant_text1}") - assert "chicken cross the road" in assistant_text1, "First assistant response not found" + assert ( + "chicken cross the road" in assistant_text1 + ), "First assistant response not found" assert "other side" in assistant_text1, "Second assistant response not found" # Test 6: Attention Mask Pattern attention_mask0 = item0["attention_mask"] sequence_length = torch.sum(attention_mask0) assert sequence_length > 0, "No tokens marked as attended in attention mask" - assert torch.all(attention_mask0[:sequence_length] == 1), "Incorrect attention mask pattern" + assert torch.all( + 
attention_mask0[:sequence_length] == 1 + ), "Incorrect attention mask pattern" if sequence_length < len(attention_mask0): - assert torch.all(attention_mask0[sequence_length:] == 0), "Padding not properly masked" + assert torch.all( + attention_mask0[sequence_length:] == 0 + ), "Padding not properly masked" # Test 7: Position IDs Pattern position_ids0 = item0["position_ids"] - assert torch.equal(position_ids0[:sequence_length], torch.arange(sequence_length)), ( - "Position IDs not sequential for non-padded tokens" - ) + assert torch.equal( + position_ids0[:sequence_length], torch.arange(sequence_length) + ), "Position IDs not sequential for non-padded tokens" if sequence_length < len(position_ids0): - assert torch.all(position_ids0[sequence_length:] == 0), "Padding position IDs not zero" + assert torch.all( + position_ids0[sequence_length:] == 0 + ), "Padding position IDs not zero" # Test 8: Verify loss mask for assistant responses # Get the full conversation text @@ -137,13 +157,15 @@ def test_multiturn_sft_dataset(): for msg in test_data["messages"][0]: # First conversation if msg["role"] == "assistant": # The content should appear in the masked text - assert msg["content"] in assistant_text, f"Assistant message '{msg['content']}' not found in masked text" + assert ( + msg["content"] in assistant_text + ), f"Assistant message '{msg['content']}' not found in masked text" # The content should NOT appear in the non-masked text non_assistant_text = tokenizer.decode(input_ids0[loss_mask0 == 0]) - assert msg["content"] not in non_assistant_text, ( - f"Assistant message '{msg['content']}' found in non-assistant text" - ) + assert ( + msg["content"] not in non_assistant_text + ), f"Assistant message '{msg['content']}' found in non-assistant text" # Test 9: Verify non-assistant parts have loss_mask=0 # Get non-assistant text @@ -153,29 +175,39 @@ def test_multiturn_sft_dataset(): # Verify that system and user messages are in the non-assistant text for msg in 
test_data["messages"][0]: # First conversation if msg["role"] in ["system", "user"]: - assert msg["content"] in non_assistant_text, ( - f"{msg['role'].title()} message '{msg['content']}' not found in non-assistant text" - ) + assert ( + msg["content"] in non_assistant_text + ), f"{msg['role'].title()} message '{msg['content']}' not found in non-assistant text" # And verify they're NOT in the assistant text - assert msg["content"] not in assistant_text, ( - f"{msg['role'].title()} message '{msg['content']}' found in assistant text" - ) + assert ( + msg["content"] not in assistant_text + ), f"{msg['role'].title()} message '{msg['content']}' found in assistant text" # Test 10: Verify padding behavior - padding_config = {"max_length": 1024, "truncation": "error", "multiturn": {"messages_key": "messages"}} - small_dataset = MultiTurnSFTDataset(parquet_files=test_file, tokenizer=tokenizer, config=padding_config) + padding_config = { + "max_length": 1024, + "truncation": "error", + "multiturn": {"messages_key": "messages"}, + } + small_dataset = MultiTurnSFTDataset( + parquet_files=test_file, tokenizer=tokenizer, config=padding_config + ) padded_item = small_dataset[0] # Get actual sequence length (before padding) actual_length = torch.sum(padded_item["attention_mask"]) # Verify padding tokens - assert torch.all(padded_item["input_ids"][actual_length:] == tokenizer.pad_token_id), ( - "Padding tokens not set correctly" - ) - assert torch.all(padded_item["attention_mask"][actual_length:] == 0), "Attention mask not set correctly for padding" - assert torch.all(padded_item["loss_mask"][actual_length:] == 0), "Loss mask not set correctly for padding" + assert torch.all( + padded_item["input_ids"][actual_length:] == tokenizer.pad_token_id + ), "Padding tokens not set correctly" + assert torch.all( + padded_item["attention_mask"][actual_length:] == 0 + ), "Attention mask not set correctly for padding" + assert torch.all( + padded_item["loss_mask"][actual_length:] == 0 + ), "Loss 
mask not set correctly for padding" print("All tests passed!") print("Starting test...") diff --git a/Agent0/executor_train/verl/tests/utils/dataset/test_rl_dataset_on_cpu.py b/Agent0/executor_train/verl/tests/utils/dataset/test_rl_dataset_on_cpu.py index 2afc3ef..6a27e8f 100644 --- a/Agent0/executor_train/verl/tests/utils/dataset/test_rl_dataset_on_cpu.py +++ b/Agent0/executor_train/verl/tests/utils/dataset/test_rl_dataset_on_cpu.py @@ -42,7 +42,13 @@ def test_rl_dataset(): ) dataset = RLHFDataset(data_files=local_path, tokenizer=tokenizer, config=config) - dataloader = DataLoader(dataset=dataset, batch_size=16, shuffle=True, drop_last=True, collate_fn=collate_fn) + dataloader = DataLoader( + dataset=dataset, + batch_size=16, + shuffle=True, + drop_last=True, + collate_fn=collate_fn, + ) a = next(iter(dataloader)) @@ -87,7 +93,13 @@ def test_image_rl_data(): processor=processor, ) - dataloader = DataLoader(dataset=dataset, batch_size=16, shuffle=True, drop_last=True, collate_fn=collate_fn) + dataloader = DataLoader( + dataset=dataset, + batch_size=16, + shuffle=True, + drop_last=True, + collate_fn=collate_fn, + ) a = next(iter(dataloader)) diff --git a/Agent0/executor_train/verl/tests/utils/reward_score/reward_score/test_sandbox_fusion_on_cpu.py b/Agent0/executor_train/verl/tests/utils/reward_score/reward_score/test_sandbox_fusion_on_cpu.py index 997cb8a..9a9d3bb 100644 --- a/Agent0/executor_train/verl/tests/utils/reward_score/reward_score/test_sandbox_fusion_on_cpu.py +++ b/Agent0/executor_train/verl/tests/utils/reward_score/reward_score/test_sandbox_fusion_on_cpu.py @@ -63,7 +63,10 @@ """ # --- Test input/output data --- -INPUT_OUTPUT_VALID = {"inputs": ["input1", "input2"], "outputs": ["output1\n", "output2\n"]} +INPUT_OUTPUT_VALID = { + "inputs": ["input1", "input2"], + "outputs": ["output1\n", "output2\n"], +} INPUT_OUTPUT_SINGLE = {"inputs": ["input1"], "outputs": ["output1\n"]} @@ -77,7 +80,9 @@ @pytest.mark.skipif(skip_condition, reason=skip_reason) def 
test_integration_success_correct(): """Integration test: Code is correct, output is correct""" - results, metadata_list = check_correctness(SANDBOX_URL, INPUT_OUTPUT_VALID, CODE_SUCCESS) + results, metadata_list = check_correctness( + SANDBOX_URL, INPUT_OUTPUT_VALID, CODE_SUCCESS + ) assert results == [True, True] assert metadata_list[0]["status"] == "success" assert metadata_list[0]["stdout"] == "output1\n" @@ -88,7 +93,9 @@ def test_integration_success_correct(): @pytest.mark.skipif(skip_condition, reason=skip_reason) def test_integration_success_wrong_output(): """Integration test: Code runs successfully, but output is wrong""" - results, metadata_list = check_correctness(SANDBOX_URL, INPUT_OUTPUT_VALID, CODE_WRONG_OUTPUT) + results, metadata_list = check_correctness( + SANDBOX_URL, INPUT_OUTPUT_VALID, CODE_WRONG_OUTPUT + ) assert results == [False, False] assert metadata_list[0]["status"] == "wrong_answer" assert metadata_list[0]["stdout"] == "wrong_output\n" @@ -98,7 +105,9 @@ def test_integration_success_wrong_output(): @pytest.mark.skipif(skip_condition, reason=skip_reason) def test_integration_compile_error(): """Integration test: Code causes compile error""" - results, metadata_list = check_correctness(SANDBOX_URL, INPUT_OUTPUT_VALID, CODE_COMPILE_ERROR, language="cpp") + results, metadata_list = check_correctness( + SANDBOX_URL, INPUT_OUTPUT_VALID, CODE_COMPILE_ERROR, language="cpp" + ) assert results == [-4, -4] assert metadata_list[0]["status"] == "compile_error" assert metadata_list[1]["status"] == "compile_error" @@ -107,7 +116,9 @@ def test_integration_compile_error(): @pytest.mark.skipif(skip_condition, reason=skip_reason) def test_integration_runtime_error(): """Integration test: Code causes runtime error""" - results, metadata_list = check_correctness(SANDBOX_URL, INPUT_OUTPUT_SINGLE, CODE_RUNTIME_ERROR) + results, metadata_list = check_correctness( + SANDBOX_URL, INPUT_OUTPUT_SINGLE, CODE_RUNTIME_ERROR + ) assert results == [-2] assert 
metadata_list[0]["status"] == "runtime_error" # More assertions can be added based on the actual API response, e.g., exit_code, stderr @@ -117,7 +128,9 @@ def test_integration_runtime_error(): def test_integration_runtime_timeout(): """Integration test: Code causes runtime timeout""" test_timeout = 5 # Set a timeout shorter than the sleep time in CODE_TIMEOUT - results, metadata_list = check_correctness(SANDBOX_URL, INPUT_OUTPUT_SINGLE, CODE_TIMEOUT, timeout=test_timeout) + results, metadata_list = check_correctness( + SANDBOX_URL, INPUT_OUTPUT_SINGLE, CODE_TIMEOUT, timeout=test_timeout + ) assert results == [-3] assert metadata_list[0]["status"] == "timeout" # More assertions can be added based on the actual API response, e.g., run_status @@ -188,7 +201,9 @@ def test_integration_concurrency_high_load(): ) # Verify results against the expected map - assert len(results) == concurrency_level, f"Expected {concurrency_level} results, got {len(results)}" + assert ( + len(results) == concurrency_level + ), f"Expected {concurrency_level} results, got {len(results)}" correct_count = 0 wrong_count = 0 @@ -210,35 +225,55 @@ def test_integration_concurrency_high_load(): f"Correct results (True): {correct_count}/" f"{concurrency_level - len(wrong_answer_indices) - len(timeout_indices)}" ) - print(f"Expected wrong answers (False, correctly identified): {wrong_count}/{len(wrong_answer_indices)}") - print(f"Expected timeouts (-3, correctly identified): {timeout_count}/{len(timeout_indices)}") + print( + f"Expected wrong answers (False, correctly identified): {wrong_count}/{len(wrong_answer_indices)}" + ) + print( + f"Expected timeouts (-3, correctly identified): {timeout_count}/{len(timeout_indices)}" + ) if unexpected_results: print("Unexpected results found:") - for idx, res, expected_str in unexpected_results[:10]: # Print first 10 unexpected - print(f" Index {idx}: Got {res}, {expected_str}. 
Metadata: {metadata_list[idx]}") + for idx, res, expected_str in unexpected_results[ + :10 + ]: # Print first 10 unexpected + print( + f" Index {idx}: Got {res}, {expected_str}. Metadata: {metadata_list[idx]}" + ) raise AssertionError(f"Found {len(unexpected_results)} unexpected results.") - assert correct_count == concurrency_level - len(wrong_answer_indices) - len(timeout_indices), ( - "Incorrect number of successful results" - ) - assert wrong_count == len(wrong_answer_indices), "Incorrect number of identified wrong answers" - assert timeout_count == len(timeout_indices), "Incorrect number of identified timeouts" + assert correct_count == concurrency_level - len(wrong_answer_indices) - len( + timeout_indices + ), "Incorrect number of successful results" + assert wrong_count == len( + wrong_answer_indices + ), "Incorrect number of identified wrong answers" + assert timeout_count == len( + timeout_indices + ), "Incorrect number of identified timeouts" # Verify metadata count and basic status of one of each type assert len(metadata_list) == concurrency_level # Find the first correct index first_correct_index = next( - i for i in range(concurrency_level) if i not in wrong_answer_indices and i not in timeout_indices + i + for i in range(concurrency_level) + if i not in wrong_answer_indices and i not in timeout_indices ) assert metadata_list[first_correct_index]["status"] == "success" - assert metadata_list[first_correct_index]["stdout"] == f"output_{first_correct_index}\n" + assert ( + metadata_list[first_correct_index]["stdout"] + == f"output_{first_correct_index}\n" + ) # Check the status of the first intentionally wrong case first_wrong_index = min(wrong_answer_indices) assert metadata_list[first_wrong_index]["status"] == "wrong_answer" assert metadata_list[first_wrong_index]["stdout"] == f"output_{first_wrong_index}\n" - assert metadata_list[first_wrong_index]["expected_output"] == f"wrong_output_{first_wrong_index}\n" + assert ( + 
metadata_list[first_wrong_index]["expected_output"] + == f"wrong_output_{first_wrong_index}\n" + ) # Check the status of the first intentionally timeout case first_timeout_index = min(timeout_indices) @@ -256,24 +291,48 @@ def test_unit_concurrency_order(mock_call_sandbox_api): generation = "print(input())" language = "python" timeout = 5 - in_outs = {"inputs": ["input1", "input2", "input3"], "outputs": ["output1", "output2", "output3"]} + in_outs = { + "inputs": ["input1", "input2", "input3"], + "outputs": ["output1", "output2", "output3"], + } def side_effect(*args, **kwargs): stdin = kwargs.get("stdin") if stdin == "input1": return ( - {"status": "Success", "run_result": {"status": "Finished", "stdout": "output1", "return_code": 0}}, + { + "status": "Success", + "run_result": { + "status": "Finished", + "stdout": "output1", + "return_code": 0, + }, + }, None, ) elif stdin == "input2": time.sleep(0.1) return ( - {"status": "Success", "run_result": {"status": "Finished", "stdout": "output2", "return_code": 0}}, + { + "status": "Success", + "run_result": { + "status": "Finished", + "stdout": "output2", + "return_code": 0, + }, + }, None, ) elif stdin == "input3": return ( - {"status": "Success", "run_result": {"status": "Finished", "stdout": "output3", "return_code": 0}}, + { + "status": "Success", + "run_result": { + "status": "Finished", + "stdout": "output3", + "return_code": 0, + }, + }, None, ) else: @@ -281,7 +340,9 @@ def side_effect(*args, **kwargs): mock_call_sandbox_api.side_effect = side_effect - results, metadata_list = check_correctness(sandbox_url, in_outs, generation, timeout, language) + results, metadata_list = check_correctness( + sandbox_url, in_outs, generation, timeout, language + ) assert results == [True, True, True] assert len(metadata_list) == 3 @@ -300,7 +361,10 @@ def test_unit_api_timeout_error_concurrent(mock_call_sandbox_api): generation = "print(input())" language = "python" timeout = 5 - in_outs = {"inputs": ["input1", 
"input2_timeout", "input3"], "outputs": ["output1", "output2", "output3"]} + in_outs = { + "inputs": ["input1", "input2_timeout", "input3"], + "outputs": ["output1", "output2", "output3"], + } api_error_message = "API Call Failed: Gateway Timeout (504) on attempt 3/3" @@ -308,14 +372,28 @@ def side_effect(*args, **kwargs): stdin = kwargs.get("stdin") if stdin == "input1": return ( - {"status": "Success", "run_result": {"status": "Finished", "stdout": "output1", "return_code": 0}}, + { + "status": "Success", + "run_result": { + "status": "Finished", + "stdout": "output1", + "return_code": 0, + }, + }, None, ) elif stdin == "input2_timeout": return (None, api_error_message) elif stdin == "input3": return ( - {"status": "Success", "run_result": {"status": "Finished", "stdout": "output3", "return_code": 0}}, + { + "status": "Success", + "run_result": { + "status": "Finished", + "stdout": "output3", + "return_code": 0, + }, + }, None, ) else: @@ -323,7 +401,9 @@ def side_effect(*args, **kwargs): mock_call_sandbox_api.side_effect = side_effect - results, metadata_list = check_correctness(sandbox_url, in_outs, generation, timeout, language) + results, metadata_list = check_correctness( + sandbox_url, in_outs, generation, timeout, language + ) assert results == [True, -1, True] assert len(metadata_list) == 3 @@ -382,7 +462,11 @@ def _mock_api_call_for_concurrency_tracking( # Return a simulated successful API response return { "status": "Success", - "run_result": {"status": "Finished", "stdout": f"mock_output_for_{stdin}", "return_code": 0}, + "run_result": { + "status": "Finished", + "stdout": f"mock_output_for_{stdin}", + "return_code": 0, + }, }, None @@ -401,20 +485,18 @@ def _process_pool_worker_for_concurrency_test( call_lock, ): # Corrected lambda to accept keyword arguments matching call_sandbox_api's usage - curried_mock_api_call = ( - lambda sandbox_fusion_url, code, stdin, compile_timeout, run_timeout, memory_limit_mb, language: ( - 
_mock_api_call_for_concurrency_tracking( - active_calls_counter, - max_calls_tracker, - call_lock, - sandbox_fusion_url, - code, - stdin, - compile_timeout, - run_timeout, - memory_limit_mb, - language, - ) + curried_mock_api_call = lambda sandbox_fusion_url, code, stdin, compile_timeout, run_timeout, memory_limit_mb, language: ( + _mock_api_call_for_concurrency_tracking( + active_calls_counter, + max_calls_tracker, + call_lock, + sandbox_fusion_url, + code, + stdin, + compile_timeout, + run_timeout, + memory_limit_mb, + language, ) ) @@ -431,7 +513,8 @@ def _process_pool_worker_for_concurrency_test( # ---- END DEBUG PRINTS ---- with patch( - "verl.utils.reward_score.sandbox_fusion.utils.call_sandbox_api", side_effect=curried_mock_api_call + "verl.utils.reward_score.sandbox_fusion.utils.call_sandbox_api", + side_effect=curried_mock_api_call, ) as mock_obj: # ---- START DEBUG PRINTS ---- print( @@ -464,7 +547,9 @@ def test_multiprocess_global_concurrency_limit_with_semaphore(): """ manager = multiprocessing.Manager() active_calls_counter = manager.Value("i", 0) # Current active mock API calls - max_calls_tracker = manager.Value("i", 0) # Observed maximum concurrent mock API calls + max_calls_tracker = manager.Value( + "i", 0 + ) # Observed maximum concurrent mock API calls call_lock = manager.Lock() # Lock to protect counters # Create a multiprocessing.Semaphore instance, this is the global semaphore we are testing. 
@@ -472,7 +557,9 @@ def test_multiprocess_global_concurrency_limit_with_semaphore(): global_mp_semaphore = manager.Semaphore(MAX_GLOBAL_CONCURRENCY_LIMIT_TEST) mock_sandbox_url = "mock_url_for_concurrency_test" - mock_generation = "pass" # Specific code content is not important as API call is mocked + mock_generation = ( + "pass" # Specific code content is not important as API call is mocked + ) mock_memory_limit_mb = 1024 mock_language = "python" mock_timeout = 5 # Timeout setting, not critical for mock calls @@ -513,9 +600,13 @@ def test_multiprocess_global_concurrency_limit_with_semaphore(): # Print some test statistics for debugging and validation print("\n--- Global Concurrency Test Stats ---") - print(f"Semaphore Limit (MAX_GLOBAL_CONCURRENCY_LIMIT_TEST): {MAX_GLOBAL_CONCURRENCY_LIMIT_TEST}") + print( + f"Semaphore Limit (MAX_GLOBAL_CONCURRENCY_LIMIT_TEST): {MAX_GLOBAL_CONCURRENCY_LIMIT_TEST}" + ) print(f"Number of Processes (NUM_PROCESSES_TEST): {NUM_PROCESSES_TEST}") - print(f"Tasks per Process (NUM_TASKS_PER_PROCESS_TEST): {NUM_TASKS_PER_PROCESS_TEST}") + print( + f"Tasks per Process (NUM_TASKS_PER_PROCESS_TEST): {NUM_TASKS_PER_PROCESS_TEST}" + ) print(f"Total Tasks Submitted: {total_tasks_expected_to_run}") print(f"Simulated API Call Duration: {SIMULATED_API_CALL_DURATION_TEST}s") print(f"Total Test Execution Time: {total_execution_time:.2f}s") @@ -523,12 +614,14 @@ def test_multiprocess_global_concurrency_limit_with_semaphore(): # print(f"Tasks processed per worker: {num_tasks_processed_per_worker}") # Verify that all submitted tasks have been processed - assert sum(num_tasks_processed_per_worker) == total_tasks_expected_to_run, ( - "Mismatch in the number of tasks processed." - ) + assert ( + sum(num_tasks_processed_per_worker) == total_tasks_expected_to_run + ), "Mismatch in the number of tasks processed." # Verify that the mock API was called at least once - assert max_calls_tracker.value > 0, "The mocked API call_sandbox_api was not called." 
+ assert ( + max_calls_tracker.value > 0 + ), "The mocked API call_sandbox_api was not called." # Core assertion: Observed maximum concurrent calls should not exceed the semaphore's limit assert max_calls_tracker.value <= MAX_GLOBAL_CONCURRENCY_LIMIT_TEST, ( @@ -563,7 +656,9 @@ def test_unit_invalid_input_format(): assert results == [-1] assert metadata_list[0]["error"] == "Invalid input/output data" - results, metadata_list = check_correctness(SANDBOX_URL, INPUT_OUTPUT_INVALID_MISSING_KEY, CODE_SUCCESS) + results, metadata_list = check_correctness( + SANDBOX_URL, INPUT_OUTPUT_INVALID_MISSING_KEY, CODE_SUCCESS + ) assert results == [-1] assert metadata_list[0]["error"] == "Invalid input/output data" @@ -571,7 +666,9 @@ def test_unit_invalid_input_format(): @pytest.mark.skipif(skip_condition, reason=skip_reason) def test_unit_input_output_mismatch(): """Unit test: Mismatch between the number of inputs and outputs""" - results, metadata_list = check_correctness(SANDBOX_URL, INPUT_OUTPUT_MISMATCH, CODE_SUCCESS) + results, metadata_list = check_correctness( + SANDBOX_URL, INPUT_OUTPUT_MISMATCH, CODE_SUCCESS + ) assert results == [-1] assert len(metadata_list) == 1 assert metadata_list[0]["error"] == "Input/output count mismatch" @@ -608,13 +705,19 @@ def solve(): test_timeout = 10 # Set a timeout value start_time = time.time() - results, metadata_list = check_correctness(SANDBOX_URL, timeout_in_outs, code_infinite_loop, timeout=test_timeout) + results, metadata_list = check_correctness( + SANDBOX_URL, timeout_in_outs, code_infinite_loop, timeout=test_timeout + ) end_time = time.time() duration = end_time - start_time - print(f"\nHigh concurrency all timeout test ({concurrency_level} cases) duration: {duration:.2f} seconds") + print( + f"\nHigh concurrency all timeout test ({concurrency_level} cases) duration: {duration:.2f} seconds" + ) # Verify all results are -3 (timeout) - assert len(results) == concurrency_level, f"Expected {concurrency_level} results, got 
{len(results)}" + assert ( + len(results) == concurrency_level + ), f"Expected {concurrency_level} results, got {len(results)}" all_timed_out = all(r == -3 for r in results) if not all_timed_out: non_timeout_indices = [i for i, r in enumerate(results) if r != -3] @@ -622,7 +725,9 @@ def solve(): # Print metadata for the first few non-timeout cases for debugging for i in non_timeout_indices[:5]: print(f"Metadata for non-timeout case {i}: {metadata_list[i]}") - assert all_timed_out, f"Not all {concurrency_level} concurrent tests resulted in timeout (-3). Results: {results}" + assert ( + all_timed_out + ), f"Not all {concurrency_level} concurrent tests resulted in timeout (-3). Results: {results}" # Verify metadata count and status of the first case assert len(metadata_list) == concurrency_level @@ -657,7 +762,9 @@ def occurrencesOfElement(self, nums: List[int], queries: List[int], x: int) -> L } # Use a short timeout for fast tests - results, metadata_list = check_correctness(SANDBOX_URL, in_outs, generation_code, timeout=5) + results, metadata_list = check_correctness( + SANDBOX_URL, in_outs, generation_code, timeout=5 + ) # from verl.utils.reward_score.prime_code import apps_check_correctness # results, metadata_list = apps_check_correctness(in_outs=in_outs, generation=generation_code, # timeout=50000, debug=True) diff --git a/Agent0/executor_train/verl/tests/utils/reward_score/test_sandbox_on_cpu.py b/Agent0/executor_train/verl/tests/utils/reward_score/test_sandbox_on_cpu.py index ff40732..7876731 100644 --- a/Agent0/executor_train/verl/tests/utils/reward_score/test_sandbox_on_cpu.py +++ b/Agent0/executor_train/verl/tests/utils/reward_score/test_sandbox_on_cpu.py @@ -33,8 +33,9 @@ """(x + 2)^2 + (y - 3)^2 """, # symbolic test ] -prime_code_answers = [ - """import sys +prime_code_answers = ( + [ + """import sys from collections import deque def main(): @@ -84,7 +85,9 @@ def main(): if __name__ == '__main__': main() """ -] * 2 + ] + * 2 +) prime_code_gts = [ """{\n 
\"inputs\": [\n \"5 7 6 11\\n3\\n5 3 8\\n6 7 11\\n5 2 5\\n\",\n \"3 4 3 10\\n3\\n3 1 4\\n4 5 9\\n3 10 10\\n\",\n \"1 1 2 10\\n2\\n1 1 3\\n2 6 10\\n\",\n \"9 8 7 8\\n9\\n10 6 6\\n10 6 6\\n7 7 8\\n9 5 6\\n8 9 9\\n9 5 5\\n9 8 8\\n8 5 6\\n9 10 10\\n\",\n \"6 15 7 15\\n9\\n6 15 15\\n7 14 14\\n6 15 15\\n9 14 14\\n7 14 16\\n6 15 15\\n6 15 15\\n7 14 14\\n8 15 15\\n\",\n \"13 16 20 10\\n18\\n13 16 16\\n20 10 10\\n19 10 10\\n12 15 15\\n20 10 10\\n18 11 11\\n19 10 10\\n19 10 10\\n20 10 10\\n19 10 10\\n20 10 10\\n20 10 10\\n19 10 10\\n18 11 11\\n13 16 16\\n12 15 15\\n19 10 10\\n19 10 10\\n\",\n \"89 29 88 30\\n16\\n87 31 31\\n14 95 95\\n98 88 89\\n96 88 88\\n14 97 97\\n13 97 98\\n100 88 88\\n88 32 32\\n99 88 89\\n90 29 29\\n87 31 31\\n15 94 96\\n89 29 29\\n88 32 32\\n97 89 89\\n88 29 30\\n\",\n \"30 14 39 19\\n31\\n35 7 11\\n37 11 12\\n32 13 13\\n37 5 6\\n46 13 13\\n37 14 14\\n31 13 13\\n43 13 19\\n45 15 19\\n46 13 13\\n32 17 17\\n41 14 19\\n30 14 14\\n43 13 17\\n34 16 18\\n44 11 19\\n38 13 13\\n40 12 20\\n37 16 18\\n46 16 18\\n34 10 14\\n36 9 10\\n36 15 19\\n38 15 19\\n42 13 19\\n33 14 15\\n35 15 19\\n33 17 18\\n39 12 20\\n36 5 7\\n45 12 12\\n\",\n \"2 1 1 1\\n2\\n1 1 2\\n2 1 2\\n\",\n \"1 1 1 2\\n5\\n1000000000 1 10000\\n19920401 1188 5566\\n1000000000 1 10000\\n1 1 10000\\n5 100 200\\n\",\n \"1 1 1000000000 2\\n5\\n1000000000 1 10000\\n19920401 1188 5566\\n1000000000 1 10000\\n1 1 10000\\n5 100 200\\n\"\n ],\n \"outputs\": [\n \"4\\n\",\n \"6\\n\",\n \"-1\\n\",\n \"2\\n\",\n \"1\\n\",\n \"-1\\n\",\n \"1\\n\",\n \"9\\n\",\n \"1\\n\",\n \"1\\n\",\n \"-1\\n\"\n ]\n}""", # A correct sample # noqa: E501 """{\n \"inputs\": [\n \"5 7 6 11\\n3\\n5 3 8\\n6 7 11\\n5 2 5\\n\",\n \"3 4 3 10\\n3\\n3 1 4\\n4 5 9\\n3 10 10\\n\",\n \"1 1 2 10\\n2\\n1 1 3\\n2 6 10\\n\",\n \"9 8 7 8\\n9\\n10 6 6\\n10 6 6\\n7 7 8\\n9 5 6\\n8 9 9\\n9 5 5\\n9 8 8\\n8 5 6\\n9 10 10\\n\",\n \"6 15 7 15\\n9\\n6 15 15\\n7 14 14\\n6 15 15\\n9 14 14\\n7 14 16\\n6 15 15\\n6 15 15\\n7 14 14\\n8 15 15\\n\",\n \"13 16 20 
10\\n18\\n13 16 16\\n20 10 10\\n19 10 10\\n12 15 15\\n20 10 10\\n18 11 11\\n19 10 10\\n19 10 10\\n20 10 10\\n19 10 10\\n20 10 10\\n20 10 10\\n19 10 10\\n18 11 11\\n13 16 16\\n12 15 15\\n19 10 10\\n19 10 10\\n\",\n \"89 29 88 30\\n16\\n87 31 31\\n14 95 95\\n98 88 89\\n96 88 88\\n14 97 97\\n13 97 98\\n100 88 88\\n88 32 32\\n99 88 89\\n90 29 29\\n87 31 31\\n15 94 96\\n89 29 29\\n88 32 32\\n97 89 89\\n88 29 30\\n\",\n \"30 14 39 19\\n31\\n35 7 11\\n37 11 12\\n32 13 13\\n37 5 6\\n46 13 13\\n37 14 14\\n31 13 13\\n43 13 19\\n45 15 19\\n46 13 13\\n32 17 17\\n41 14 19\\n30 14 14\\n43 13 17\\n34 16 18\\n44 11 19\\n38 13 13\\n40 12 20\\n37 16 18\\n46 16 18\\n34 10 14\\n36 9 10\\n36 15 19\\n38 15 19\\n42 13 19\\n33 14 15\\n35 15 19\\n33 17 18\\n39 12 20\\n36 5 7\\n45 12 12\\n\",\n \"2 1 1 1\\n2\\n1 1 2\\n2 1 2\\n\",\n \"1 1 1 2\\n5\\n1000000000 1 10000\\n19920401 1188 5566\\n1000000000 1 10000\\n1 1 10000\\n5 100 200\\n\",\n \"1 1 1000000000 2\\n5\\n1000000000 1 10000\\n19920401 1188 5566\\n1000000000 1 10000\\n1 1 10000\\n5 100 200\\n\"\n ],\n \"outputs\": [\n \"4\\n\",\n \"6\\n\",\n \"-1\\n\",\n \"-1\\n\",\n \"1\\n\",\n \"-1\\n\",\n \"1\\n\",\n \"9\\n\",\n \"1\\n\",\n \"1\\n\",\n \"-1\\n\"\n ]\n}""", # noqa: E501 @@ -110,7 +113,13 @@ def test_parallelism(): data_sources.extend(["numina_aops_forum"] * len(prime_math_answers)) scores = asyncio.run( - parallel_compute_score_async(default_compute_score, sequences_str, ground_truth, data_sources, num_processes=16) + parallel_compute_score_async( + default_compute_score, + sequences_str, + ground_truth, + data_sources, + num_processes=16, + ) ) print(scores) @@ -120,13 +129,18 @@ def test_prime_code(): Test PRIME code sandbox. 
""" data_source = "codecontests" - for completion, ground_truth, score_ in zip(prime_code_answers, prime_code_gts, prime_code_scores, strict=True): + for completion, ground_truth, score_ in zip( + prime_code_answers, prime_code_gts, prime_code_scores, strict=True + ): score = default_compute_score(data_source, completion, ground_truth) assert float(score) == score_ # Use the pytest.mark.skipif decorator to skip the test -@pytest.mark.skipif(not os.environ.get("SANDBOX_FUSION_URL"), reason="SANDBOX_FUSION_URL environment variable not set") +@pytest.mark.skipif( + not os.environ.get("SANDBOX_FUSION_URL"), + reason="SANDBOX_FUSION_URL environment variable not set", +) def test_prime_code_sandbox_fusion(): """ Test PRIME code on sandbox fusion. Skips if SANDBOX_FUSION_URL is not set. @@ -136,14 +150,22 @@ def test_prime_code_sandbox_fusion(): sandbox_fusion_url = os.environ.get("SANDBOX_FUSION_URL") # Removed the previous 'if not sandbox_url' check block - for completion, ground_truth, score_ in zip(prime_code_answers, prime_code_gts, prime_code_scores, strict=True): + for completion, ground_truth, score_ in zip( + prime_code_answers, prime_code_gts, prime_code_scores, strict=True + ): score = default_compute_score( - data_source, completion, ground_truth, extra_info={"sandbox_fusion_url": sandbox_fusion_url} + data_source, + completion, + ground_truth, + extra_info={"sandbox_fusion_url": sandbox_fusion_url}, ) # <-- Use the URL obtained from the environment variable assert float(score) == score_ -@pytest.mark.skipif(not os.environ.get("SANDBOX_FUSION_URL"), reason="SANDBOX_FUSION_URL environment variable not set") +@pytest.mark.skipif( + not os.environ.get("SANDBOX_FUSION_URL"), + reason="SANDBOX_FUSION_URL environment variable not set", +) def test_continuous_score_consistency(): """ Verify that continuous score calculation is consistent between prime_code and sandbox_fusion. @@ -155,12 +177,18 @@ def test_continuous_score_consistency(): # 1. 
Calculate score using prime_code (default) with continuous=True prime_score, _ = sandbox_fusion.compute_score( - os.environ.get("SANDBOX_FUSION_URL"), None, completion, ground_truth, continuous=True + os.environ.get("SANDBOX_FUSION_URL"), + None, + completion, + ground_truth, + continuous=True, ) # 2. Calculate score using sandbox_fusion with continuous=True # Ensure the extra_info key triggers the sandbox_fusion path in default_compute_score - fusion_score, _ = prime_code.compute_score(completion, ground_truth, continuous=True) + fusion_score, _ = prime_code.compute_score( + completion, ground_truth, continuous=True + ) # 3. Assert scores are equal (using pytest.approx for float comparison) assert float(prime_score) == pytest.approx(expected_continuous_score) @@ -173,13 +201,20 @@ def test_continuous_score_consistency(): def test_check_correctness(): completion = prime_code_answers[0] ground_truth = json.loads(prime_code_gts[0]) - ground_truth_single = {"inputs": ground_truth["inputs"][:1], "outputs": ground_truth["outputs"][:1]} - res, meta = apps_check_correctness(in_outs=ground_truth_single, generation=completion, timeout=5, debug=False) + ground_truth_single = { + "inputs": ground_truth["inputs"][:1], + "outputs": ground_truth["outputs"][:1], + } + res, meta = apps_check_correctness( + in_outs=ground_truth_single, generation=completion, timeout=5, debug=False + ) print(res, meta) def test_prime_math(): data_source = "numina_aops_forum" - for completion, ground_truth in zip(prime_math_answers, prime_math_gts, strict=True): + for completion, ground_truth in zip( + prime_math_answers, prime_math_gts, strict=True + ): score = default_compute_score(data_source, completion, ground_truth) assert float(score) == 1.0 diff --git a/Agent0/executor_train/verl/tests/utils/test_activation_offload.py b/Agent0/executor_train/verl/tests/utils/test_activation_offload.py index 2393d79..9186614 100644 --- a/Agent0/executor_train/verl/tests/utils/test_activation_offload.py +++ 
b/Agent0/executor_train/verl/tests/utils/test_activation_offload.py @@ -26,10 +26,16 @@ from verl.utils.activation_offload import enable_activation_offloading from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager -from verl.utils.fsdp_utils import MixedPrecisionPolicy, apply_fsdp2, get_fsdp_wrap_policy +from verl.utils.fsdp_utils import ( + MixedPrecisionPolicy, + apply_fsdp2, + get_fsdp_wrap_policy, +) -def _fsdp_activation_offloading_test(rank, world_size, rendezvous_file, strategy="fsdp"): +def _fsdp_activation_offloading_test( + rank, world_size, rendezvous_file, strategy="fsdp" +): torch.cuda.set_device(rank) torch.distributed.init_process_group( backend="nccl", @@ -37,19 +43,27 @@ def _fsdp_activation_offloading_test(rank, world_size, rendezvous_file, strategy rank=rank, world_size=world_size, ) - device_mesh = init_device_mesh("cuda", mesh_shape=(world_size,), mesh_dim_names=("dp",)) + device_mesh = init_device_mesh( + "cuda", mesh_shape=(world_size,), mesh_dim_names=("dp",) + ) model_name = "Qwen/Qwen2.5-0.5B-Instruct" config = Qwen2Config(num_hidden_layers=4) with torch.device("cuda"): model = AutoModelForCausalLM.from_config( - config=config, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" + config=config, + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2", ) model = model.to(device="cuda") # Wrap model with FSDP - mixed_precision = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.float32) + mixed_precision = MixedPrecision( + param_dtype=torch.bfloat16, + reduce_dtype=torch.float32, + buffer_dtype=torch.float32, + ) if strategy == "fsdp": model = FSDP( @@ -63,7 +77,9 @@ def _fsdp_activation_offloading_test(rank, world_size, rendezvous_file, strategy ) else: mp_policy = MixedPrecisionPolicy( - param_dtype=torch.bfloat16, reduce_dtype=torch.float32, cast_forward_inputs=True + param_dtype=torch.bfloat16, + reduce_dtype=torch.float32, + 
cast_forward_inputs=True, ) fsdp_kwargs = { "mesh": device_mesh, @@ -103,7 +119,9 @@ def _fsdp_activation_offloading_test(rank, world_size, rendezvous_file, strategy # Save checkpoint after first update temp_dir = tempfile.mkdtemp() checkpoint_path = os.path.join(temp_dir, "checkpoint") - checkpoint_manager.save_checkpoint(local_path=checkpoint_path, hdfs_path=None, global_step=0) + checkpoint_manager.save_checkpoint( + local_path=checkpoint_path, hdfs_path=None, global_step=0 + ) # Step 2: Second update and forward pass outputs2 = model(input_ids=input_ids2, attention_mask=attention_mask2) @@ -115,7 +133,9 @@ def _fsdp_activation_offloading_test(rank, world_size, rendezvous_file, strategy # Record logits after second update with torch.no_grad(): - logits_without_offloading = model(input_ids=input_ids2, attention_mask=attention_mask2).logits + logits_without_offloading = model( + input_ids=input_ids2, attention_mask=attention_mask2 + ).logits # Step 3: wrap module with activation offloading and load checkpoint enable_activation_offloading(model, "fsdp") @@ -131,10 +151,14 @@ def _fsdp_activation_offloading_test(rank, world_size, rendezvous_file, strategy # Record logits after loaded checkpoint and update with torch.no_grad(): - logits_with_offloading = model(input_ids=input_ids2, attention_mask=attention_mask2).logits + logits_with_offloading = model( + input_ids=input_ids2, attention_mask=attention_mask2 + ).logits # Step 4: Verify outputs match - torch.testing.assert_close(logits_without_offloading, logits_with_offloading, atol=0.0, rtol=0.0) + torch.testing.assert_close( + logits_without_offloading, logits_with_offloading, atol=0.0, rtol=0.0 + ) print(f"Activaiton offloading for {strategy} test passed on {world_size} GPUs!") # Cleanup diff --git a/Agent0/executor_train/verl/tests/utils/test_flops_counter.py b/Agent0/executor_train/verl/tests/utils/test_flops_counter.py index 0b8889b..a71a8d3 100644 --- 
a/Agent0/executor_train/verl/tests/utils/test_flops_counter.py +++ b/Agent0/executor_train/verl/tests/utils/test_flops_counter.py @@ -147,11 +147,15 @@ def test_flops_counter(config_type: str): config = Config(test_config["config"]) flops_counter = FlopsCounter(config) for batch_seqlens, expected_flops in zip( - test_config["batch_seqlens_tuple"], test_config["expected_flops_tuple"], strict=True + test_config["batch_seqlens_tuple"], + test_config["expected_flops_tuple"], + strict=True, ): # set delta time to 1 to get the flops counted_flops, _ = flops_counter.estimate_flops(batch_seqlens, 1) - print(f"Expect flops for {test_config['config']} is {expected_flops}, but get {counted_flops}") - assert math.isclose(counted_flops, expected_flops), ( + print( f"Expect flops for {test_config['config']} is {expected_flops}, but get {counted_flops}" ) + assert math.isclose( + counted_flops, expected_flops + ), f"Expect flops for {test_config['config']} is {expected_flops}, but get {counted_flops}" diff --git a/Agent0/executor_train/verl/tests/utils/test_fs_on_cpu.py b/Agent0/executor_train/verl/tests/utils/test_fs_on_cpu.py index 7ae85e0..7ffd7c8 100644 --- a/Agent0/executor_train/verl/tests/utils/test_fs_on_cpu.py +++ b/Agent0/executor_train/verl/tests/utils/test_fs_on_cpu.py @@ -58,7 +58,9 @@ def fake_copy(src: str, dst: str, *args, **kwargs): # Test initial copy local_path = fs.copy_to_local(hdfs_path, cache_dir=test_cache) - expected_path = os.path.join(test_cache, fs.md5_encode(hdfs_path), os.path.basename(hdfs_path)) + expected_path = os.path.join( + test_cache, fs.md5_encode(hdfs_path), os.path.basename(hdfs_path) + ) assert local_path == expected_path assert os.path.exists(local_path) diff --git a/Agent0/executor_train/verl/tests/utils/test_import_utils_on_cpu.py b/Agent0/executor_train/verl/tests/utils/test_import_utils_on_cpu.py index 59709b8..29feb17 100644 --- a/Agent0/executor_train/verl/tests/utils/test_import_utils_on_cpu.py +++ 
b/Agent0/executor_train/verl/tests/utils/test_import_utils_on_cpu.py @@ -84,7 +84,9 @@ def test_load_extern_type_invalid_module(): # Create a temporary file with syntax errors import tempfile - with tempfile.NamedTemporaryFile(suffix=".py", mode="w+", delete=False) as temp_file: + with tempfile.NamedTemporaryFile( + suffix=".py", mode="w+", delete=False + ) as temp_file: temp_file.write("This is not valid Python syntax :") temp_path = temp_file.name diff --git a/Agent0/executor_train/verl/tests/utils/test_linear_cross_entropy.py b/Agent0/executor_train/verl/tests/utils/test_linear_cross_entropy.py index 0512d13..5867ed3 100644 --- a/Agent0/executor_train/verl/tests/utils/test_linear_cross_entropy.py +++ b/Agent0/executor_train/verl/tests/utils/test_linear_cross_entropy.py @@ -46,7 +46,11 @@ def run_torch_entropy( - hidden: torch.Tensor, weight: torch.Tensor, labels: torch.Tensor, temperature: float, reduction="none" + hidden: torch.Tensor, + weight: torch.Tensor, + labels: torch.Tensor, + temperature: float, + reduction="none", ) -> list[torch.Tensor]: hidden = hidden.squeeze(0).to(torch.float32) weight = weight.transpose(0, 1).to(torch.float32) @@ -56,7 +60,9 @@ def run_torch_entropy( entropy_a = torch.logsumexp(logits, dim=-1) # [num_tokens] entropy_b = torch.sum(pd * logits, dim=-1) # [num_tokens] entropy = entropy_a - entropy_b - logprobs = torch.nn.functional.cross_entropy(logits, labels.squeeze(0), reduction=reduction) # [num_tokens] + logprobs = torch.nn.functional.cross_entropy( + logits, labels.squeeze(0), reduction=reduction + ) # [num_tokens] logprobs = torch.neg(logprobs) return logprobs, entropy @@ -74,7 +80,9 @@ def run_verl_original_entropy( # compute entropy entropy = compute_entropy_from_logits(logits) # ((total_nnz / sp) + pad) # if use_sp: ((total_nnz / sp) + pad) ; if not use_sp: (batch, seqlen) - logprobs = logprobs_from_logits(logits=logits, labels=labels, inplace_backward=False) + logprobs = logprobs_from_logits( + logits=logits, 
labels=labels, inplace_backward=False + ) return logprobs, entropy @@ -144,21 +152,33 @@ def generate_hyper(self): def generate_forward_inputs(self): hidden = ( - torch.empty((self.batch_size, self.num_tokens, self.hidden_size), dtype=self.dtype, device="cuda") + torch.empty( + (self.batch_size, self.num_tokens, self.hidden_size), + dtype=self.dtype, + device="cuda", + ) .uniform_(-0.5, 0.5) .requires_grad_() ) weight = ( - torch.empty((self.vocab_size, self.hidden_size), dtype=self.dtype, device="cuda") + torch.empty( + (self.vocab_size, self.hidden_size), dtype=self.dtype, device="cuda" + ) .uniform_(-0.5, 0.5) .requires_grad_() ) - labels = torch.randint(0, self.vocab_size, (self.batch_size, self.num_tokens), device="cuda") + labels = torch.randint( + 0, self.vocab_size, (self.batch_size, self.num_tokens), device="cuda" + ) return hidden, weight, labels def generate_backward_inputs(self): - g_entropy = torch.empty((self.num_tokens,), dtype=self.dtype, device="cuda").uniform_(-0.5, 0.5) - g_logprobs = torch.empty((self.num_tokens,), dtype=self.dtype, device="cuda").uniform_(-1, 1) + g_entropy = torch.empty( + (self.num_tokens,), dtype=self.dtype, device="cuda" + ).uniform_(-0.5, 0.5) + g_logprobs = torch.empty( + (self.num_tokens,), dtype=self.dtype, device="cuda" + ).uniform_(-1, 1) return g_entropy, g_logprobs def verify_correctness(self, iterations=5): @@ -182,13 +202,17 @@ def verify_correctness(self, iterations=5): hidden, weight, labels = self.generate_forward_inputs() start_event.record() - (torch_logprobs, torch_entropy) = run_torch_entropy(hidden, weight, labels, self.temperature) + (torch_logprobs, torch_entropy) = run_torch_entropy( + hidden, weight, labels, self.temperature + ) end_event.record() torch.cuda.synchronize() torch_forward_latency.append(start_event.elapsed_time(end_event)) start_event.record() - (verl_logprobs, verl_entropy) = run_verl_original_entropy(hidden, weight, labels, self.temperature) + (verl_logprobs, verl_entropy) = 
run_verl_original_entropy( + hidden, weight, labels, self.temperature + ) end_event.record() torch.cuda.synchronize() verl_forward_latency.append(start_event.elapsed_time(end_event)) @@ -202,32 +226,61 @@ def verify_correctness(self, iterations=5): verl_fused_forward_latency.append(start_event.elapsed_time(end_event)) start_event.record() - (kernel_logprobs, kernel_entropy) = linear_cross_entropy(hidden, weight, labels, self.temperature) + (kernel_logprobs, kernel_entropy) = linear_cross_entropy( + hidden, weight, labels, self.temperature + ) end_event.record() torch.cuda.synchronize() kernel_forward_latency.append(start_event.elapsed_time(end_event)) - torch.testing.assert_close(torch_logprobs, verl_logprobs, atol=1e-4, rtol=1e-4) - torch.testing.assert_close(torch_entropy, verl_entropy, atol=1e-4, rtol=1e-4) + torch.testing.assert_close( + torch_logprobs, verl_logprobs, atol=1e-4, rtol=1e-4 + ) + torch.testing.assert_close( + torch_entropy, verl_entropy, atol=1e-4, rtol=1e-4 + ) - torch.testing.assert_close(torch_logprobs, verl_fused_logprobs, atol=1e-4, rtol=1e-4) - torch.testing.assert_close(torch_entropy, verl_fused_entropy, atol=1e-4, rtol=1e-4) - torch.testing.assert_close(verl_logprobs, verl_fused_logprobs, atol=1e-4, rtol=1e-4) - torch.testing.assert_close(verl_entropy, verl_fused_entropy, atol=1e-4, rtol=1e-4) + torch.testing.assert_close( + torch_logprobs, verl_fused_logprobs, atol=1e-4, rtol=1e-4 + ) + torch.testing.assert_close( + torch_entropy, verl_fused_entropy, atol=1e-4, rtol=1e-4 + ) + torch.testing.assert_close( + verl_logprobs, verl_fused_logprobs, atol=1e-4, rtol=1e-4 + ) + torch.testing.assert_close( + verl_entropy, verl_fused_entropy, atol=1e-4, rtol=1e-4 + ) - torch.testing.assert_close(torch_logprobs, kernel_logprobs, atol=1e-3, rtol=2e-4) - torch.testing.assert_close(torch_entropy, kernel_entropy, atol=5e-3, rtol=5e-4) - torch.testing.assert_close(verl_logprobs, kernel_logprobs, atol=1e-3, rtol=2e-4) - 
torch.testing.assert_close(verl_entropy, kernel_entropy, atol=5e-3, rtol=5e-4) - torch.testing.assert_close(verl_fused_logprobs, kernel_logprobs, atol=1e-3, rtol=2e-4) - torch.testing.assert_close(verl_fused_entropy, kernel_entropy, atol=5e-3, rtol=5e-4) + torch.testing.assert_close( + torch_logprobs, kernel_logprobs, atol=1e-3, rtol=2e-4 + ) + torch.testing.assert_close( + torch_entropy, kernel_entropy, atol=5e-3, rtol=5e-4 + ) + torch.testing.assert_close( + verl_logprobs, kernel_logprobs, atol=1e-3, rtol=2e-4 + ) + torch.testing.assert_close( + verl_entropy, kernel_entropy, atol=5e-3, rtol=5e-4 + ) + torch.testing.assert_close( + verl_fused_logprobs, kernel_logprobs, atol=1e-3, rtol=2e-4 + ) + torch.testing.assert_close( + verl_fused_entropy, kernel_entropy, atol=5e-3, rtol=5e-4 + ) # backward g_entropy, g_logprobs = self.generate_backward_inputs() start_event.record() (d_torch_hidden, d_torch_weight) = torch.autograd.grad( - (torch_entropy, torch_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False + (torch_entropy, torch_logprobs), + (hidden, weight), + (g_entropy, g_logprobs), + retain_graph=False, ) end_event.record() torch.cuda.synchronize() @@ -235,7 +288,10 @@ def verify_correctness(self, iterations=5): start_event.record() (d_verl_hidden, d_verl_weight) = torch.autograd.grad( - (verl_entropy, verl_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False + (verl_entropy, verl_logprobs), + (hidden, weight), + (g_entropy, g_logprobs), + retain_graph=False, ) end_event.record() torch.cuda.synchronize() @@ -243,7 +299,10 @@ def verify_correctness(self, iterations=5): start_event.record() (d_verl_fused_hidden, d_verl_fused_weight) = torch.autograd.grad( - (verl_fused_entropy, verl_fused_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False + (verl_fused_entropy, verl_fused_logprobs), + (hidden, weight), + (g_entropy, g_logprobs), + retain_graph=False, ) end_event.record() torch.cuda.synchronize() @@ 
-251,28 +310,59 @@ def verify_correctness(self, iterations=5): start_event.record() (d_kernel_hidden, d_kernel_weight) = torch.autograd.grad( - (kernel_entropy, kernel_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False + (kernel_entropy, kernel_logprobs), + (hidden, weight), + (g_entropy, g_logprobs), + retain_graph=False, ) end_event.record() torch.cuda.synchronize() kernel_backward_latency.append(start_event.elapsed_time(end_event)) - torch.testing.assert_close(d_torch_hidden, d_verl_hidden, atol=1e-2, rtol=1e-4) - torch.testing.assert_close(d_torch_weight, d_verl_weight, atol=1e-2, rtol=1e-4) + torch.testing.assert_close( + d_torch_hidden, d_verl_hidden, atol=1e-2, rtol=1e-4 + ) + torch.testing.assert_close( + d_torch_weight, d_verl_weight, atol=1e-2, rtol=1e-4 + ) - torch.testing.assert_close(d_torch_hidden, d_verl_fused_hidden, atol=1e-2, rtol=1e-4) - torch.testing.assert_close(d_torch_weight, d_verl_fused_weight, atol=1e-2, rtol=1e-4) - torch.testing.assert_close(d_verl_hidden, d_verl_fused_hidden, atol=1e-2, rtol=1e-4) - torch.testing.assert_close(d_verl_weight, d_verl_fused_weight, atol=1e-2, rtol=1e-4) - torch.testing.assert_close(d_torch_hidden, d_verl_hidden, atol=1e-2, rtol=1e-4) - torch.testing.assert_close(d_torch_weight, d_verl_weight, atol=1e-2, rtol=1e-4) + torch.testing.assert_close( + d_torch_hidden, d_verl_fused_hidden, atol=1e-2, rtol=1e-4 + ) + torch.testing.assert_close( + d_torch_weight, d_verl_fused_weight, atol=1e-2, rtol=1e-4 + ) + torch.testing.assert_close( + d_verl_hidden, d_verl_fused_hidden, atol=1e-2, rtol=1e-4 + ) + torch.testing.assert_close( + d_verl_weight, d_verl_fused_weight, atol=1e-2, rtol=1e-4 + ) + torch.testing.assert_close( + d_torch_hidden, d_verl_hidden, atol=1e-2, rtol=1e-4 + ) + torch.testing.assert_close( + d_torch_weight, d_verl_weight, atol=1e-2, rtol=1e-4 + ) - torch.testing.assert_close(d_torch_hidden, d_kernel_hidden, atol=2e-2, rtol=4e-2) - torch.testing.assert_close(d_torch_weight, 
d_kernel_weight, atol=2e-2, rtol=4e-2) - torch.testing.assert_close(d_verl_hidden, d_kernel_hidden, atol=2e-2, rtol=4e-2) - torch.testing.assert_close(d_verl_weight, d_kernel_weight, atol=2e-2, rtol=4e-2) - torch.testing.assert_close(d_verl_fused_hidden, d_kernel_hidden, atol=2e-2, rtol=4e-2) - torch.testing.assert_close(d_verl_fused_weight, d_kernel_weight, atol=2e-2, rtol=4e-2) + torch.testing.assert_close( + d_torch_hidden, d_kernel_hidden, atol=2e-2, rtol=4e-2 + ) + torch.testing.assert_close( + d_torch_weight, d_kernel_weight, atol=2e-2, rtol=4e-2 + ) + torch.testing.assert_close( + d_verl_hidden, d_kernel_hidden, atol=2e-2, rtol=4e-2 + ) + torch.testing.assert_close( + d_verl_weight, d_kernel_weight, atol=2e-2, rtol=4e-2 + ) + torch.testing.assert_close( + d_verl_fused_hidden, d_kernel_hidden, atol=2e-2, rtol=4e-2 + ) + torch.testing.assert_close( + d_verl_fused_weight, d_kernel_weight, atol=2e-2, rtol=4e-2 + ) # remove first latency torch_forward_latency = torch_forward_latency[1:] @@ -329,17 +419,24 @@ def check_storage(self, method_name, run_forward): (logprobs, entropy) = run_forward(hidden, weight, labels, self.temperature) torch.cuda.synchronize() torch_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 - print(f"[INFO]: {method_name} Forward pass peak memory: {torch_max_memory:.2f} MB") + print( + f"[INFO]: {method_name} Forward pass peak memory: {torch_max_memory:.2f} MB" + ) g_entropy, g_logprobs = self.generate_backward_inputs() torch.cuda.reset_peak_memory_stats() (d_torch_hidden, d_torch_weight) = torch.autograd.grad( - (entropy, logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False + (entropy, logprobs), + (hidden, weight), + (g_entropy, g_logprobs), + retain_graph=False, ) torch.cuda.synchronize() torch_backward_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 - print(f"[INFO]: {method_name} Backward pass peak memory: {torch_backward_max_memory:.2f} MB") + print( + f"[INFO]: {method_name} Backward pass 
peak memory: {torch_backward_max_memory:.2f} MB" + ) def check_storage_all(self): self.check_storage("Torch", run_torch_entropy) diff --git a/Agent0/executor_train/verl/tests/utils/test_linear_cross_entropy_tp.py b/Agent0/executor_train/verl/tests/utils/test_linear_cross_entropy_tp.py index 9c1f868..eff9034 100644 --- a/Agent0/executor_train/verl/tests/utils/test_linear_cross_entropy_tp.py +++ b/Agent0/executor_train/verl/tests/utils/test_linear_cross_entropy_tp.py @@ -40,7 +40,11 @@ # FIXME: remove these manually included paths import sys - sys.path.append(os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../"))) + sys.path.append( + os.path.abspath( + os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../") + ) + ) finally: from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy @@ -55,7 +59,11 @@ def run_torch_entropy( - hidden: torch.Tensor, weight: torch.Tensor, labels: torch.Tensor, temperature: float, reduction="none" + hidden: torch.Tensor, + weight: torch.Tensor, + labels: torch.Tensor, + temperature: float, + reduction="none", ) -> list[torch.Tensor]: # [num_tokens, vocab_size] if len(hidden.shape) > 2: @@ -64,14 +72,20 @@ def run_torch_entropy( labels = labels.view(-1) logits = torch.matmul( hidden.to(torch.float32), - weight.to(torch.float32) if weight.size(0) == hidden.size(1) else weight.T.to(torch.float32), + ( + weight.to(torch.float32) + if weight.size(0) == hidden.size(1) + else weight.T.to(torch.float32) + ), ) logits /= temperature pd = torch.nn.functional.softmax(logits, dim=-1) # [num_tokens, vocab_size] entropy_a = torch.logsumexp(logits, dim=-1) # [num_tokens] entropy_b = torch.sum(pd * logits, dim=-1) # [num_tokens] entropy = entropy_a - entropy_b - logprobs = torch.nn.functional.cross_entropy(logits, labels, reduction=reduction) # [num_tokens] + logprobs = torch.nn.functional.cross_entropy( + logits, labels, reduction=reduction + ) # [num_tokens] logprobs = torch.neg(logprobs) return 
logprobs, entropy @@ -98,10 +112,15 @@ def forward( if len(labels.shape) > 1: labels = labels.view(-1) - logits = torch.matmul(hidden.to(torch.float32), weight.to(torch.float32).T) # [num_tokens, vocab_size] + logits = torch.matmul( + hidden.to(torch.float32), weight.to(torch.float32).T + ) # [num_tokens, vocab_size] logits /= temperature whole_logits = torch.empty( - (logits.shape[0], logits.shape[1] * dist.get_world_size(dist_process_group)), + ( + logits.shape[0], + logits.shape[1] * dist.get_world_size(dist_process_group), + ), dtype=logits.dtype, device=logits.device, ) @@ -116,7 +135,9 @@ def forward( entropy_b = torch.sum(pd * whole_logits, dim=-1) # [num_tokens] entropy = entropy_a - entropy_b - logprobs = torch.nn.functional.cross_entropy(whole_logits, labels, reduction="none") + logprobs = torch.nn.functional.cross_entropy( + whole_logits, labels, reduction="none" + ) logprobs = torch.neg(logprobs) ctx.save_for_backward(hidden, weight, labels, whole_logits, entropy_b) @@ -148,7 +169,9 @@ def backward(ctx, g_logprobs: torch.Tensor, g_entropy: torch.Tensor): # d_entropy/d_logits = d_entropy_a - d_entropy_b # d_entropy/d_logits = pd - pd * (logits - b.unsqueeze(1) + 1) # d_entropy/d_logits = -pd * (logits - b.unsqueeze(1)) - d_logits_entropy = g_entropy.unsqueeze(1) * (-pd * (whole_logits - entropy_b.unsqueeze(1))) + d_logits_entropy = g_entropy.unsqueeze(1) * ( + -pd * (whole_logits - entropy_b.unsqueeze(1)) + ) # Gradient for logprobs # logprobs = -cross_entropy = -log(pd[labels]) @@ -241,21 +264,33 @@ def generate_hyper(self): def generate_forward_inputs(self): hidden = ( - torch.empty((self.batch_size, self.num_tokens, self.hidden_size), dtype=self.dtype, device="cuda") + torch.empty( + (self.batch_size, self.num_tokens, self.hidden_size), + dtype=self.dtype, + device="cuda", + ) .uniform_(-0.5, 0.5) .requires_grad_() ) weight = ( - torch.empty((self.vocab_size, self.hidden_size), dtype=self.dtype, device="cuda") + torch.empty( + (self.vocab_size, 
self.hidden_size), dtype=self.dtype, device="cuda" + ) .uniform_(-0.5, 0.5) .requires_grad_() ) - labels = torch.randint(0, self.vocab_size, (self.batch_size, self.num_tokens), device="cuda") + labels = torch.randint( + 0, self.vocab_size, (self.batch_size, self.num_tokens), device="cuda" + ) return hidden, weight, labels def generate_backward_inputs(self): - g_entropy = torch.empty((self.num_tokens,), dtype=self.dtype, device="cuda").uniform_(-0.5, 0.5) - g_logprobs = torch.empty((self.num_tokens,), dtype=self.dtype, device="cuda").uniform_(-1, 1) + g_entropy = torch.empty( + (self.num_tokens,), dtype=self.dtype, device="cuda" + ).uniform_(-0.5, 0.5) + g_logprobs = torch.empty( + (self.num_tokens,), dtype=self.dtype, device="cuda" + ).uniform_(-1, 1) return g_entropy, g_logprobs def verify_torch_itself(self, iterations: int = 5): @@ -276,12 +311,15 @@ def verify_torch_itself(self, iterations: int = 5): # Create a single contiguous tensor to hold all gathered weights whole_weight = torch.empty( - (self.vocab_size * self.world_size, self.hidden_size), dtype=weight.dtype, device=weight.device + (self.vocab_size * self.world_size, self.hidden_size), + dtype=weight.dtype, + device=weight.device, ) # Create views into the tensor for each rank's portion whole_weight_views = [ - whole_weight[i * self.vocab_size : (i + 1) * self.vocab_size] for i in range(self.world_size) + whole_weight[i * self.vocab_size : (i + 1) * self.vocab_size] + for i in range(self.world_size) ] # Perform all_gather operation using the views @@ -290,11 +328,17 @@ def verify_torch_itself(self, iterations: int = 5): # Set requires_grad for autograd whole_weight.requires_grad_() - (single_logprobs, single_entropy) = run_torch_entropy(hidden, whole_weight, labels, self.temperature) + (single_logprobs, single_entropy) = run_torch_entropy( + hidden, whole_weight, labels, self.temperature + ) - (tp_logprobs, tp_entropy) = run_torch_entropy_tp(hidden, weight, labels, self.temperature, self.group) + 
(tp_logprobs, tp_entropy) = run_torch_entropy_tp( + hidden, weight, labels, self.temperature, self.group + ) - torch.testing.assert_close(single_logprobs, tp_logprobs, atol=1e-4, rtol=1e-4) + torch.testing.assert_close( + single_logprobs, tp_logprobs, atol=1e-4, rtol=1e-4 + ) torch.testing.assert_close(single_entropy, tp_entropy, atol=1e-4, rtol=1e-4) # backward pass @@ -304,22 +348,34 @@ def verify_torch_itself(self, iterations: int = 5): dist.broadcast(g_logprobs, src=0, group=self.group) (single_d_hidden, single_d_weight) = torch.autograd.grad( - (single_entropy, single_logprobs), (hidden, whole_weight), (g_entropy, g_logprobs), retain_graph=False + (single_entropy, single_logprobs), + (hidden, whole_weight), + (g_entropy, g_logprobs), + retain_graph=False, ) (tp_d_hidden, tp_d_weight) = torch.autograd.grad( - (tp_entropy, tp_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False + (tp_entropy, tp_logprobs), + (hidden, weight), + (g_entropy, g_logprobs), + retain_graph=False, ) # NOTE: all-reduce on hidden is conducted outside the kernel dist.all_reduce(tp_d_hidden, op=dist.ReduceOp.SUM, group=self.group) - torch.testing.assert_close(tp_d_hidden, single_d_hidden, atol=1e-2, rtol=1e-4) + torch.testing.assert_close( + tp_d_hidden, single_d_hidden, atol=1e-2, rtol=1e-4 + ) # Extract the corresponding slice from single_d_weight for comparison # tp_d_weight has shape [vocab_size, hidden_size] # single_d_weight has shape [vocab_size * world_size, hidden_size] torch.testing.assert_close( tp_d_weight, - single_d_weight[self.local_rank * self.vocab_size : (self.local_rank + 1) * self.vocab_size], + single_d_weight[ + self.local_rank + * self.vocab_size : (self.local_rank + 1) + * self.vocab_size + ], atol=1e-2, rtol=1e-4, ) @@ -339,7 +395,9 @@ def check_torch_storage(self): dist.broadcast(labels, src=0, group=self.group) torch.cuda.reset_peak_memory_stats() - (tp_logprobs, tp_entropy) = run_torch_entropy_tp(hidden, weight, labels, self.temperature, 
self.group) + (tp_logprobs, tp_entropy) = run_torch_entropy_tp( + hidden, weight, labels, self.temperature, self.group + ) torch.cuda.synchronize() forward_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 @@ -350,7 +408,10 @@ def check_torch_storage(self): torch.cuda.reset_peak_memory_stats() (d_tp_hidden, d_tp_weight) = torch.autograd.grad( - (tp_entropy, tp_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False + (tp_entropy, tp_logprobs), + (hidden, weight), + (g_entropy, g_logprobs), + retain_graph=False, ) torch.cuda.synchronize() backward_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 @@ -358,8 +419,12 @@ def check_torch_storage(self): dist.all_reduce(d_tp_hidden, op=dist.ReduceOp.SUM, group=self.group) if self.local_rank == 0: - print(f"[INFO]: Torch Forward pass peak memory: {forward_max_memory:.2f} MB") - print(f"[INFO]: Torch Backward pass peak memory: {backward_max_memory:.2f} MB") + print( + f"[INFO]: Torch Forward pass peak memory: {forward_max_memory:.2f} MB" + ) + print( + f"[INFO]: Torch Backward pass peak memory: {backward_max_memory:.2f} MB" + ) def verify_kernel_correctness(self, iterations: int = 5): self.cleanup() @@ -381,7 +446,9 @@ def verify_kernel_correctness(self, iterations: int = 5): dist.broadcast(labels, src=0, group=self.group) start_event.record() - (torch_logprobs, torch_entropy) = run_torch_entropy_tp(hidden, weight, labels, self.temperature, self.group) + (torch_logprobs, torch_entropy) = run_torch_entropy_tp( + hidden, weight, labels, self.temperature, self.group + ) end_event.record() torch.cuda.synchronize() torch_forward_latency.append(start_event.elapsed_time(end_event)) @@ -394,8 +461,12 @@ def verify_kernel_correctness(self, iterations: int = 5): torch.cuda.synchronize() kernel_forward_latency.append(start_event.elapsed_time(end_event)) - torch.testing.assert_close(torch_logprobs, kernel_logprobs, atol=1e-1, rtol=1e-2) - torch.testing.assert_close(torch_entropy, kernel_entropy, 
atol=1e-1, rtol=1e-2) + torch.testing.assert_close( + torch_logprobs, kernel_logprobs, atol=1e-1, rtol=1e-2 + ) + torch.testing.assert_close( + torch_entropy, kernel_entropy, atol=1e-1, rtol=1e-2 + ) # backward pass g_entropy, g_logprobs = self.generate_backward_inputs() @@ -405,7 +476,10 @@ def verify_kernel_correctness(self, iterations: int = 5): start_event.record() (torch_d_hidden, torch_d_weight) = torch.autograd.grad( - (torch_entropy, torch_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False + (torch_entropy, torch_logprobs), + (hidden, weight), + (g_entropy, g_logprobs), + retain_graph=False, ) end_event.record() torch.cuda.synchronize() @@ -415,7 +489,10 @@ def verify_kernel_correctness(self, iterations: int = 5): start_event.record() (kernel_d_hidden, kernel_d_weight) = torch.autograd.grad( - (kernel_entropy, kernel_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False + (kernel_entropy, kernel_logprobs), + (hidden, weight), + (g_entropy, g_logprobs), + retain_graph=False, ) end_event.record() torch.cuda.synchronize() @@ -423,8 +500,12 @@ def verify_kernel_correctness(self, iterations: int = 5): # NOTE: all-reduce on hidden is conducted outside the kernel dist.all_reduce(kernel_d_hidden, op=dist.ReduceOp.SUM, group=self.group) - torch.testing.assert_close(torch_d_hidden, kernel_d_hidden, atol=2e-2, rtol=4e-2) - torch.testing.assert_close(torch_d_weight, kernel_d_weight, atol=2e-2, rtol=4e-2) + torch.testing.assert_close( + torch_d_hidden, kernel_d_hidden, atol=2e-2, rtol=4e-2 + ) + torch.testing.assert_close( + torch_d_weight, kernel_d_weight, atol=2e-2, rtol=4e-2 + ) # remove first latency torch_forward_latency = torch_forward_latency[1:] @@ -476,7 +557,10 @@ def check_kernel_storage(self): torch.cuda.reset_peak_memory_stats() (d_kernel_hidden, d_kernel_weight) = torch.autograd.grad( - (kernel_entropy, kernel_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False + (kernel_entropy, 
kernel_logprobs), + (hidden, weight), + (g_entropy, g_logprobs), + retain_graph=False, ) torch.cuda.synchronize() kernel_backward_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 @@ -484,8 +568,12 @@ def check_kernel_storage(self): dist.all_reduce(d_kernel_hidden, op=dist.ReduceOp.SUM, group=self.group) if self.local_rank == 0: - print(f"[INFO]: Kernel Forward pass peak memory: {kernel_max_memory:.2f} MB") - print(f"[INFO]: Kernel Backward pass peak memory: {kernel_backward_max_memory:.2f} MB") + print( + f"[INFO]: Kernel Forward pass peak memory: {kernel_max_memory:.2f} MB" + ) + print( + f"[INFO]: Kernel Backward pass peak memory: {kernel_backward_max_memory:.2f} MB" + ) if __name__ == "__main__": diff --git a/Agent0/executor_train/verl/tests/utils/test_model_on_cpu.py b/Agent0/executor_train/verl/tests/utils/test_model_on_cpu.py index 8b1416c..2d1c32c 100644 --- a/Agent0/executor_train/verl/tests/utils/test_model_on_cpu.py +++ b/Agent0/executor_train/verl/tests/utils/test_model_on_cpu.py @@ -24,7 +24,10 @@ "override_kwargs", [ {"param_a": 5, "new_param": "plain_added"}, - {"param_a": 2, "nested_params": {"sub_param_x": "updated_x", "sub_param_z": True}}, + { + "param_a": 2, + "nested_params": {"sub_param_x": "updated_x", "sub_param_z": True}, + }, ], ) def test_update_model_config(override_kwargs): @@ -34,7 +37,9 @@ def test_update_model_config(override_kwargs): """ # Create a fresh mock config object for each test case mock_config = SimpleNamespace( - param_a=1, nested_params=SimpleNamespace(sub_param_x="original_x", sub_param_y=100), other_param="keep_me" + param_a=1, + nested_params=SimpleNamespace(sub_param_x="original_x", sub_param_y=100), + other_param="keep_me", ) # Apply the updates using the parametrized override_kwargs update_model_config(mock_config, override_kwargs) @@ -42,11 +47,25 @@ def test_update_model_config(override_kwargs): # Assertions to check if the config was updated correctly if "nested_params" in override_kwargs: # Case 2: 
Nested override override_nested = override_kwargs["nested_params"] - assert mock_config.nested_params.sub_param_x == override_nested["sub_param_x"], "Nested sub_param_x mismatch" - assert mock_config.nested_params.sub_param_y == 100, "Nested sub_param_y should be unchanged" - assert hasattr(mock_config.nested_params, "sub_param_z"), "Expected nested sub_param_z to be added" - assert mock_config.nested_params.sub_param_z == override_nested["sub_param_z"], "Value of sub_param_z mismatch" + assert ( + mock_config.nested_params.sub_param_x == override_nested["sub_param_x"] + ), "Nested sub_param_x mismatch" + assert ( + mock_config.nested_params.sub_param_y == 100 + ), "Nested sub_param_y should be unchanged" + assert hasattr( + mock_config.nested_params, "sub_param_z" + ), "Expected nested sub_param_z to be added" + assert ( + mock_config.nested_params.sub_param_z == override_nested["sub_param_z"] + ), "Value of sub_param_z mismatch" else: # Case 1: Plain override (nested params untouched) - assert mock_config.nested_params.sub_param_x == "original_x", "Nested sub_param_x should be unchanged" - assert mock_config.nested_params.sub_param_y == 100, "Nested sub_param_y should be unchanged" - assert not hasattr(mock_config.nested_params, "sub_param_z"), "Nested sub_param_z should not exist" + assert ( + mock_config.nested_params.sub_param_x == "original_x" + ), "Nested sub_param_x should be unchanged" + assert ( + mock_config.nested_params.sub_param_y == 100 + ), "Nested sub_param_y should be unchanged" + assert not hasattr( + mock_config.nested_params, "sub_param_z" + ), "Nested sub_param_z should not exist" diff --git a/Agent0/executor_train/verl/tests/utils/test_nvtx_profile.py b/Agent0/executor_train/verl/tests/utils/test_nvtx_profile.py index 3450260..938d58f 100644 --- a/Agent0/executor_train/verl/tests/utils/test_nvtx_profile.py +++ b/Agent0/executor_train/verl/tests/utils/test_nvtx_profile.py @@ -42,8 +42,12 @@ def test_config_init(self): assert 
isinstance(profiler_config, ProfilerConfig) with self.assertRaises(AttributeError): _ = profiler_config.non_existing_key - assert config.get("non_existing_key") == profiler_config.get("non_existing_key") - assert config.get("non_existing_key", 1) == profiler_config.get("non_existing_key", 1) + assert config.get("non_existing_key") == profiler_config.get( + "non_existing_key" + ) + assert config.get("non_existing_key", 1) == profiler_config.get( + "non_existing_key", 1 + ) assert config["discrete"] == profiler_config["discrete"] from dataclasses import FrozenInstanceError @@ -73,7 +77,10 @@ def test_initialization(self): self.assertEqual(self.profiler.discrete, False) def test_start_stop_profiling(self): - with patch("torch.cuda.profiler.start") as mock_start, patch("torch.cuda.profiler.stop") as mock_stop: + with ( + patch("torch.cuda.profiler.start") as mock_start, + patch("torch.cuda.profiler.stop") as mock_stop, + ): # Test start self.profiler.start() self.assertTrue(self.profiler.this_step) @@ -88,7 +95,10 @@ def test_discrete_profiling(self): discrete_config = ProfilerConfig(discrete=True, all_ranks=True) profiler = NsightSystemsProfiler(self.rank, discrete_config) - with patch("torch.cuda.profiler.start") as mock_start, patch("torch.cuda.profiler.stop") as mock_stop: + with ( + patch("torch.cuda.profiler.start") as mock_start, + patch("torch.cuda.profiler.stop") as mock_stop, + ): profiler.start() self.assertTrue(profiler.this_step) mock_start.assert_not_called() # Shouldn't start immediately in discrete mode @@ -109,7 +119,9 @@ def test_func(self, *args, **kwargs): with ( patch("torch.cuda.profiler.start") as mock_start, patch("torch.cuda.profiler.stop") as mock_stop, - patch("verl.utils.profiler.nvtx_profile.mark_start_range") as mock_start_range, + patch( + "verl.utils.profiler.nvtx_profile.mark_start_range" + ) as mock_start_range, patch("verl.utils.profiler.nvtx_profile.mark_end_range") as mock_end_range, ): result = test_func(mock_self) @@ -133,7 +145,9 
@@ def test_func(self, *args, **kwargs): with ( patch("torch.cuda.profiler.start") as mock_start, patch("torch.cuda.profiler.stop") as mock_stop, - patch("verl.utils.profiler.nvtx_profile.mark_start_range") as mock_start_range, + patch( + "verl.utils.profiler.nvtx_profile.mark_start_range" + ) as mock_start_range, patch("verl.utils.profiler.nvtx_profile.mark_end_range") as mock_end_range, ): result = test_func(mock_self) diff --git a/Agent0/executor_train/verl/tests/utils/test_rollout_trace_on_cpu.py b/Agent0/executor_train/verl/tests/utils/test_rollout_trace_on_cpu.py index e9358c1..d4344ed 100644 --- a/Agent0/executor_train/verl/tests/utils/test_rollout_trace_on_cpu.py +++ b/Agent0/executor_train/verl/tests/utils/test_rollout_trace_on_cpu.py @@ -18,7 +18,11 @@ import pytest -from verl.utils.rollout_trace import RolloutTraceConfig, rollout_trace_attr, rollout_trace_op +from verl.utils.rollout_trace import ( + RolloutTraceConfig, + rollout_trace_attr, + rollout_trace_op, +) @pytest.fixture(autouse=True) @@ -39,7 +43,10 @@ def mock_weave_client(): # Also mock the call_context if it's used internally by the decorator mock_weave.trace.context.call_context.return_value = MagicMock() - with patch.dict(sys.modules, {"weave": mock_weave, "weave.trace.context": mock_weave.trace.context}): + with patch.dict( + sys.modules, + {"weave": mock_weave, "weave.trace.context": mock_weave.trace.context}, + ): yield mock_client @@ -78,7 +85,9 @@ async def test_rollout_trace_on_untraced_class(): async def test_rollout_trace_with_tracer(mock_weave_client): """Tests that the decorator calls the tracer's methods correctly.""" - RolloutTraceConfig.init(project_name="my-project", experiment_name="my-experiment", backend="weave") + RolloutTraceConfig.init( + project_name="my-project", experiment_name="my-experiment", backend="weave" + ) instance = TracedClass() assert RolloutTraceConfig.get_client() is mock_weave_client @@ -97,7 +106,9 @@ async def 
test_rollout_trace_with_tracer(mock_weave_client): async def test_rollout_trace_with_exception(mock_weave_client): """Tests that `finish` is called with the exception when one is raised.""" - RolloutTraceConfig.init(project_name="my-project", experiment_name="my-experiment", backend="weave") + RolloutTraceConfig.init( + project_name="my-project", experiment_name="my-experiment", backend="weave" + ) instance = TracedClass() with pytest.raises(ValueError, match="Test Exception"): @@ -116,7 +127,9 @@ async def test_rollout_trace_with_exception(mock_weave_client): async def test_rollout_trace_with_dummy_backend(mock_weave_client): """Tests that the tracer is not called when the backend is 'dummy'.""" - RolloutTraceConfig.init(project_name="my-project", experiment_name="my-experiment", backend="dummy") + RolloutTraceConfig.init( + project_name="my-project", experiment_name="my-experiment", backend="dummy" + ) instance = TracedClass() await instance.my_method("test_a") @@ -132,7 +145,9 @@ async def test_rollout_trace_with_real_weave_backend(): """Integration test with a real weave backend.""" # This assumes that the weave environment (e.g., project) is configured - RolloutTraceConfig.init(project_name="my-project", experiment_name="my-experiment", backend="weave") + RolloutTraceConfig.init( + project_name="my-project", experiment_name="my-experiment", backend="weave" + ) instance = TracedClass() @@ -142,4 +157,6 @@ async def test_rollout_trace_with_real_weave_backend(): with pytest.raises(ValueError, match="Test Exception"): await instance.my_method_with_exception() - print("\nWeave integration test ran successfully. Check your weave project for the trace.") + print( + "\nWeave integration test ran successfully. Check your weave project for the trace." 
+ ) diff --git a/Agent0/executor_train/verl/tests/utils/test_seqlen_balancing.py b/Agent0/executor_train/verl/tests/utils/test_seqlen_balancing.py index df7760b..31bc719 100644 --- a/Agent0/executor_train/verl/tests/utils/test_seqlen_balancing.py +++ b/Agent0/executor_train/verl/tests/utils/test_seqlen_balancing.py @@ -18,18 +18,27 @@ from verl import DataProto from verl.utils.model import create_random_mask -from verl.utils.seqlen_balancing import ceildiv, get_reverse_idx, rearrange_micro_batches +from verl.utils.seqlen_balancing import ( + ceildiv, + get_reverse_idx, + rearrange_micro_batches, +) def test_seqlen_balancing(): input_ids = torch.randint(low=0, high=10, size=(20, 100)) attention_mask = create_random_mask( - input_ids=input_ids, max_ratio_of_left_padding=0.1, max_ratio_of_valid_token=0.9, min_ratio_of_valid_token=0.5 + input_ids=input_ids, + max_ratio_of_left_padding=0.1, + max_ratio_of_valid_token=0.9, + min_ratio_of_valid_token=0.5, ) data = {"input_ids": input_ids, "attention_mask": attention_mask} dataproto = DataProto.from_single_dict(data) - micro_batches, micro_bsz_idx_lst = rearrange_micro_batches(dataproto.batch, max_token_len=300) + micro_batches, micro_bsz_idx_lst = rearrange_micro_batches( + dataproto.batch, max_token_len=300 + ) batch = torch.cat(micro_batches) micro_bsz_idx = [] for idx in micro_bsz_idx_lst: diff --git a/Agent0/executor_train/verl/tests/utils/test_timeout_decorator_cpu.py b/Agent0/executor_train/verl/tests/utils/test_timeout_decorator_cpu.py index 3417469..ce90969 100644 --- a/Agent0/executor_train/verl/tests/utils/test_timeout_decorator_cpu.py +++ b/Agent0/executor_train/verl/tests/utils/test_timeout_decorator_cpu.py @@ -107,13 +107,17 @@ def test_slow_task_timeout(): # Renamed from test_multiprocessing_slow_task_tim with pytest.raises(TimeoutError) as excinfo: # Use pytest.raises slow_task(1) # Check the error message from the multiprocessing implementation - assert f"timed out after {TEST_TIMEOUT_SECONDS} seconds" in 
str(excinfo.value) # Use pytest assert + assert f"timed out after {TEST_TIMEOUT_SECONDS} seconds" in str( + excinfo.value + ) # Use pytest assert def test_internal_exception(): # Renamed from test_multiprocessing_internal_exception """Tests timeout correctly propagates internal exceptions.""" # Apply the default timeout decorator dynamically to the undecorated function - decorated_task = timeout(seconds=TEST_TIMEOUT_SECONDS)(task_raises_value_error) # Apply decorator dynamically + decorated_task = timeout(seconds=TEST_TIMEOUT_SECONDS)( + task_raises_value_error + ) # Apply decorator dynamically with pytest.raises(ValueError) as excinfo: # Use pytest.raises decorated_task() # Call the dynamically decorated function assert str(excinfo.value) == "Specific value error from task" # Use pytest assert @@ -132,7 +136,9 @@ def plain_quick_task_logic(): time.sleep(0.1) return "quick_ok_signal" - decorated_task = timeout(seconds=TEST_TIMEOUT_SECONDS, use_signals=True)(plain_quick_task_logic) + decorated_task = timeout(seconds=TEST_TIMEOUT_SECONDS, use_signals=True)( + plain_quick_task_logic + ) assert decorated_task() == "quick_ok_signal" # Use pytest assert @@ -144,14 +150,20 @@ def plain_slow_task_logic(): time.sleep(LONG_TASK_DURATION) return "slow_finished_signal" - decorated_task = timeout(seconds=TEST_TIMEOUT_SECONDS, use_signals=True)(plain_slow_task_logic) + decorated_task = timeout(seconds=TEST_TIMEOUT_SECONDS, use_signals=True)( + plain_slow_task_logic + ) with pytest.raises(TimeoutError) as excinfo: # Use pytest.raises decorated_task() # Check the error message (falls back to multiprocessing message on POSIX) - assert f"timed out after {TEST_TIMEOUT_SECONDS} seconds" in str(excinfo.value) # Use pytest assert + assert f"timed out after {TEST_TIMEOUT_SECONDS} seconds" in str( + excinfo.value + ) # Use pytest assert -@pytest.mark.skip(reason="this test won't pass. Just to show why use_signals should not be used") +@pytest.mark.skip( + reason="this test won't pass. 
Just to show why use_signals should not be used" +) def test_signal_in_thread_does_not_timeout(): """ Tests that signal-based timeout does NOT work reliably in a child thread. diff --git a/Agent0/executor_train/verl/tests/utils/test_torch_functional.py b/Agent0/executor_train/verl/tests/utils/test_torch_functional.py index 900cb5d..5ff2164 100644 --- a/Agent0/executor_train/verl/tests/utils/test_torch_functional.py +++ b/Agent0/executor_train/verl/tests/utils/test_torch_functional.py @@ -19,7 +19,11 @@ import torch.distributed as dist import torch.multiprocessing as mp -from verl.utils.torch_functional import distributed_masked_mean, distributed_mean_max_min_std, masked_mean +from verl.utils.torch_functional import ( + distributed_masked_mean, + distributed_mean_max_min_std, + masked_mean, +) def _worker_mean(rank: int, world_size: int, rendezvous_file: str): @@ -99,7 +103,9 @@ def _worker_mask(rank: int, world_size: int, rendezvous_file: str): valid_values = [1.0] + [2 * i + 2.0 for i in range(1, world_size)] expected_mean = sum(valid_values) / len(valid_values) - assert torch.allclose(gmean.cpu(), torch.tensor(expected_mean)), f"masked_mean@{rank}" + assert torch.allclose( + gmean.cpu(), torch.tensor(expected_mean) + ), f"masked_mean@{rank}" dist.destroy_process_group() diff --git a/Agent0/executor_train/verl/tests/workers/reward_manager/test_registry_on_cpu.py b/Agent0/executor_train/verl/tests/workers/reward_manager/test_registry_on_cpu.py index 9932ae8..7103fa9 100644 --- a/Agent0/executor_train/verl/tests/workers/reward_manager/test_registry_on_cpu.py +++ b/Agent0/executor_train/verl/tests/workers/reward_manager/test_registry_on_cpu.py @@ -15,14 +15,20 @@ import pytest # Assuming REWARD_MANAGER_REGISTRY is defined somewhere in the module -from verl.workers.reward_manager.registry import REWARD_MANAGER_REGISTRY, get_reward_manager_cls, register +from verl.workers.reward_manager.registry import ( + REWARD_MANAGER_REGISTRY, + get_reward_manager_cls, + register, 
+) @pytest.fixture def setup(): """Setup test cases with a mock registry.""" REWARD_MANAGER_REGISTRY.clear() - REWARD_MANAGER_REGISTRY.update({"manager1": "Manager1Class", "manager2": "Manager2Class"}) + REWARD_MANAGER_REGISTRY.update( + {"manager1": "Manager1Class", "manager2": "Manager2Class"} + ) return REWARD_MANAGER_REGISTRY diff --git a/Agent0/executor_train/verl/tests/workers/rollout/async_rollout_utils.py b/Agent0/executor_train/verl/tests/workers/rollout/async_rollout_utils.py index 22f2029..fdf34df 100644 --- a/Agent0/executor_train/verl/tests/workers/rollout/async_rollout_utils.py +++ b/Agent0/executor_train/verl/tests/workers/rollout/async_rollout_utils.py @@ -34,21 +34,29 @@ def init_async_rollout_manager(config: DictConfig) -> AsyncLLMServerManager: mapping = { Role.ActorRollout: global_pool_id, } - resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) + resource_pool_manager = ResourcePoolManager( + resource_pool_spec=resource_pool_spec, mapping=mapping + ) resource_pool_manager.create_resource_pool() - resource_pool_to_cls = {pool: {} for pool in resource_pool_manager.resource_pool_dict.values()} + resource_pool_to_cls = { + pool: {} for pool in resource_pool_manager.resource_pool_dict.values() + } # create actor and rollout resource_pool = resource_pool_manager.get_resource_pool(Role.ActorRollout) actor_rollout_cls = RayClassWithInitArgs( - cls=role_worker_mapping[Role.ActorRollout], config=config.actor_rollout_ref, role="actor_rollout" + cls=role_worker_mapping[Role.ActorRollout], + config=config.actor_rollout_ref, + role="actor_rollout", ) resource_pool_to_cls[resource_pool]["actor_rollout"] = actor_rollout_cls all_wg = {} for resource_pool, class_dict in resource_pool_to_cls.items(): worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict) - wg_dict = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls) + wg_dict = RayWorkerGroup( + resource_pool=resource_pool, 
ray_cls_with_init=worker_dict_cls + ) spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys()) all_wg.update(spawn_wg) actor_rollout_wg = all_wg["actor_rollout"] diff --git a/Agent0/executor_train/verl/tests/workers/rollout/perf/vllm_async_rollout.py b/Agent0/executor_train/verl/tests/workers/rollout/perf/vllm_async_rollout.py index dbcd255..1ba4a87 100644 --- a/Agent0/executor_train/verl/tests/workers/rollout/perf/vllm_async_rollout.py +++ b/Agent0/executor_train/verl/tests/workers/rollout/perf/vllm_async_rollout.py @@ -38,7 +38,11 @@ from torch.utils.data import SequentialSampler from torchdata.stateful_dataloader import StatefulDataLoader -from tests.experimental.agent_loop.agent_utils import AgentLoopManager, RayWorkerGroup, init_agent_loop_manager +from tests.experimental.agent_loop.agent_utils import ( + AgentLoopManager, + RayWorkerGroup, + init_agent_loop_manager, +) from verl.protocol import DataProto from verl.utils import hf_tokenizer from verl.utils.dataset import RLHFDataset @@ -71,7 +75,9 @@ def init_config(n_gpus_per_node) -> DictConfig: return config -def initialize(config, backend) -> tuple[AgentLoopManager | RayWorkerGroup, StatefulDataLoader]: +def initialize( + config, backend +) -> tuple[AgentLoopManager | RayWorkerGroup, StatefulDataLoader]: env_vars = { "NCCL_DEBUG": "WARN", "VLLM_USE_V1": "1", @@ -132,4 +138,9 @@ def perf_rollout(mode, backend, n_gpus_per_node, num_steps): # test_cases = [("sync", "sync"), ("async", "zeromq"), ("async", "ray")] test_cases = [("async", "zeromq"), ("async", "ray")] for mode, backend in test_cases: - perf_rollout(mode=mode, backend=backend, n_gpus_per_node=n_gpus_per_node, num_steps=num_steps) + perf_rollout( + mode=mode, + backend=backend, + n_gpus_per_node=n_gpus_per_node, + num_steps=num_steps, + ) diff --git a/Agent0/executor_train/verl/tests/workers/rollout/rollout_vllm/run_fsdp_vllm.py b/Agent0/executor_train/verl/tests/workers/rollout/rollout_vllm/run_fsdp_vllm.py index 6922389..d4ed947 100644 --- 
a/Agent0/executor_train/verl/tests/workers/rollout/rollout_vllm/run_fsdp_vllm.py +++ b/Agent0/executor_train/verl/tests/workers/rollout/rollout_vllm/run_fsdp_vllm.py @@ -19,7 +19,11 @@ import torch.distributed as dist from torch.distributed.fsdp import CPUOffload, MixedPrecision from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.distributed.fsdp.api import ShardedStateDictConfig, ShardingStrategy, StateDictType +from torch.distributed.fsdp.api import ( + ShardedStateDictConfig, + ShardingStrategy, + StateDictType, +) from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer from vllm import SamplingParams @@ -39,9 +43,13 @@ def main(): local_model_path = copy_to_local(src=hdfs_path, cache_dir=local_cache_path) tokenizer = AutoTokenizer.from_pretrained(local_model_path, trust_remote_code=True) - actor_model_config = AutoConfig.from_pretrained(local_model_path, trust_remote_code=True) + actor_model_config = AutoConfig.from_pretrained( + local_model_path, trust_remote_code=True + ) with torch.device("cuda"): - actor_model = AutoModelForCausalLM.from_pretrained(local_model_path, trust_remote_code=True) + actor_model = AutoModelForCausalLM.from_pretrained( + local_model_path, trust_remote_code=True + ) actor_model.to(torch.bfloat16) max_prompt_length = 16 @@ -57,8 +65,12 @@ def main(): attention_mask = prompts["attention_mask"] from verl.utils.torch_functional import pad_sequence_to_length - input_ids = pad_sequence_to_length(input_ids, max_prompt_length, tokenizer.pad_token_id, left_pad=True).cuda() - attention_mask = pad_sequence_to_length(attention_mask, max_prompt_length, 0, left_pad=True).cuda() + input_ids = pad_sequence_to_length( + input_ids, max_prompt_length, tokenizer.pad_token_id, left_pad=True + ).cuda() + attention_mask = pad_sequence_to_length( + attention_mask, max_prompt_length, 0, left_pad=True + ).cuda() from transformers import GenerationConfig @@ -85,9 +97,15 @@ def main(): tensor_model_parallel_size = 4 
from torch.distributed.device_mesh import init_device_mesh - device_mesh = init_device_mesh("cuda", mesh_shape=(world_size,), mesh_dim_names=["fsdp"]) + device_mesh = init_device_mesh( + "cuda", mesh_shape=(world_size,), mesh_dim_names=["fsdp"] + ) - mixed_precision = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.float32) + mixed_precision = MixedPrecision( + param_dtype=torch.bfloat16, + reduce_dtype=torch.float32, + buffer_dtype=torch.float32, + ) fsdp_model = FSDP( actor_model, use_orig_params=True, @@ -101,13 +119,21 @@ def main(): ) FSDP.set_state_dict_type( - fsdp_model, state_dict_type=StateDictType.SHARDED_STATE_DICT, state_dict_config=ShardedStateDictConfig() + fsdp_model, + state_dict_type=StateDictType.SHARDED_STATE_DICT, + state_dict_config=ShardedStateDictConfig(), ) state_dict = fsdp_model.state_dict() sampling_params = SamplingParams( - temperature=0, top_p=1, n=1, max_tokens=response_length, logprobs=1, ignore_eos=True, detokenize=False + temperature=0, + top_p=1, + n=1, + max_tokens=response_length, + logprobs=1, + ignore_eos=True, + detokenize=False, ) print(actor_model_config) @@ -145,13 +171,19 @@ def main(): idx_list = [] batch_size = input_ids.shape[0] - pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id + pad_token_id = ( + tokenizer.pad_token_id + if tokenizer.pad_token_id is not None + else tokenizer.eos_token_id + ) from verl.workers.rollout.vllm_rollout.vllm_rollout_spmd import _pre_process_inputs for i in range(batch_size): idx_list.append(_pre_process_inputs(pad_token_id, input_ids[i])) print("start generation") - outputs = llm.generate(prompt_token_ids=idx_list, sampling_params=sampling_params, use_tqdm=False) + outputs = llm.generate( + prompt_token_ids=idx_list, sampling_params=sampling_params, use_tqdm=False + ) vllm_output = outputs[0].cuda() if torch.distributed.get_rank() == 0: print(f"hf response: {tokenizer.batch_decode(response)}") 
diff --git a/Agent0/executor_train/verl/tests/workers/rollout/rollout_vllm/test_vllm_chat_scheduler.py b/Agent0/executor_train/verl/tests/workers/rollout/rollout_vllm/test_vllm_chat_scheduler.py index 93aca6a..5bf5d92 100644 --- a/Agent0/executor_train/verl/tests/workers/rollout/rollout_vllm/test_vllm_chat_scheduler.py +++ b/Agent0/executor_train/verl/tests/workers/rollout/rollout_vllm/test_vllm_chat_scheduler.py @@ -75,7 +75,12 @@ def test_vllm_async_rollout_without_tool_calls(init_config): "content": "Let's play a role playing game. Your name is Alice, your favorite color is blue.", } ], - [{"role": "user", "content": "Let's play a role playing game. Your name is Bob, your favorite color is red."}], + [ + { + "role": "user", + "content": "Let's play a role playing game. Your name is Bob, your favorite color is red.", + } + ], ] batch = DataProto( non_tensor_batch={ @@ -120,7 +125,9 @@ def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema: schema = get_json_schema(self.get_current_temperature) return OpenAIFunctionToolSchema(**schema) - async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]: + async def execute( + self, instance_id: str, parameters: dict[str, Any], **kwargs + ) -> tuple[str, float, dict]: try: result = self.get_current_temperature(**parameters) return json.dumps(result), 0, {} @@ -151,7 +158,9 @@ def get_temperature_date(self, location: str, date: str, unit: str = "celsius"): "unit": unit, } - async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]: + async def execute( + self, instance_id: str, parameters: dict[str, Any], **kwargs + ) -> tuple[str, float, dict]: try: result = self.get_temperature_date(**parameters) return json.dumps(result), 0, {} @@ -205,12 +214,17 @@ def test_vllm_async_rollout_with_tool_calls(init_config): "content": "You are Qwen, created by Alibaba Cloud. 
You are a helpful assistant.\n\n" "Current Date: 2024-09-30", }, - {"role": "user", "content": "What's the temperature in San Francisco now? How about tomorrow?"}, + { + "role": "user", + "content": "What's the temperature in San Francisco now? How about tomorrow?", + }, ], ] batch = DataProto( non_tensor_batch={ - "raw_prompt": np.array([np.array(prompt) for prompt in raw_prompts], dtype=object), + "raw_prompt": np.array( + [np.array(prompt) for prompt in raw_prompts], dtype=object + ), }, ) result = async_rollout_manager.generate_sequences(prompts=batch) @@ -228,14 +242,20 @@ def test_vllm_async_rollout_with_tool_calls(init_config): tokenizer = hf_tokenizer(init_config.actor_rollout_ref.model.path) responses = result.batch["responses"] response_mask = result.batch["response_mask"] - assert responses.size() == response_mask.size(), f"{responses.size()} != {response_mask.size()}" + assert ( + responses.size() == response_mask.size() + ), f"{responses.size()} != {response_mask.size()}" # Decode responses with response_mask for i in range(len(responses)): valid_tokens = responses[i][response_mask[i].bool()] response_str = tokenizer.decode(valid_tokens) - assert "" not in response_str, f"found in response: {response_str}" - assert "" not in response_str, f"found in response: {response_str}" + assert ( + "" not in response_str + ), f"found in response: {response_str}" + assert ( + "" not in response_str + ), f"found in response: {response_str}" print(f"response: {response_str}") print("Test passed!") diff --git a/Agent0/executor_train/verl/tests/workers/rollout/rollout_vllm/test_vllm_model_rope_scaling.py b/Agent0/executor_train/verl/tests/workers/rollout/rollout_vllm/test_vllm_model_rope_scaling.py index 30c9ae2..8e03a0b 100644 --- a/Agent0/executor_train/verl/tests/workers/rollout/rollout_vllm/test_vllm_model_rope_scaling.py +++ b/Agent0/executor_train/verl/tests/workers/rollout/rollout_vllm/test_vllm_model_rope_scaling.py @@ -61,12 +61,16 @@ def 
test_vllm_rollout_with_yarn_position_embeddings(): } ) - tokenizer = AutoTokenizer.from_pretrained(config.model_path, trust_remote_code=True, padding_side="left") + tokenizer = AutoTokenizer.from_pretrained( + config.model_path, trust_remote_code=True, padding_side="left" + ) tokenizer.pad_token = tokenizer.eos_token model_hf_config = AutoConfig.from_pretrained(config.model_path) # do_sample=False for temperate=0 deterministic - input_dataproto = prepare_input_dataproto(tokenizer, config, validate=True, do_sample=False) + input_dataproto = prepare_input_dataproto( + tokenizer, config, validate=True, do_sample=False + ) vllm_rollout = vLLMRollout( model_path=config.model_path, @@ -80,11 +84,15 @@ def test_vllm_rollout_with_yarn_position_embeddings(): ) if rank == 0: print("VLLM Rollout Outputs:") - print(tokenizer.batch_decode(rollout_response.batch["responses"][:], skip_special_tokens=False)) - for response in rollout_response.batch["responses"]: - assert "<|im_end|>" in tokenizer.decode(response, skip_special_tokens=False), ( - "Response should contain <|im_end|> token" + print( + tokenizer.batch_decode( + rollout_response.batch["responses"][:], skip_special_tokens=False ) + ) + for response in rollout_response.batch["responses"]: + assert "<|im_end|>" in tokenizer.decode( + response, skip_special_tokens=False + ), "Response should contain <|im_end|> token" print("Checks passed.") del vllm_rollout @@ -99,15 +107,27 @@ def prepare_input_dataproto(tokenizer, config, validate, do_sample=False): base_phrase = "Roses are red, sky is blue. 
" * 4096 preencode_prompts = [ # 32810 tokens > 32768 tokens - [{"role": "user", "content": base_phrase + "Who won the Champions League in 2019?"}], + [ + { + "role": "user", + "content": base_phrase + "Who won the Champions League in 2019?", + } + ], [{"role": "user", "content": base_phrase + "The founder of Apple is"}], [{"role": "user", "content": base_phrase + "What's your name"}], ] formatted_prompts = [ - tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True) + tokenizer.apply_chat_template( + conversation, tokenize=False, add_generation_prompt=True + ) for conversation in preencode_prompts ] - prompts = tokenizer(formatted_prompts, return_tensors="pt", padding="max_length", max_length=config.prompt_length) + prompts = tokenizer( + formatted_prompts, + return_tensors="pt", + padding="max_length", + max_length=config.prompt_length, + ) input_dataproto = DataProto.from_dict( { "input_ids": prompts["input_ids"], diff --git a/Agent0/executor_train/verl/tests/workers/rollout/rollout_vllm/test_vllm_spmd.py b/Agent0/executor_train/verl/tests/workers/rollout/rollout_vllm/test_vllm_spmd.py index c2b8f51..50643fc 100644 --- a/Agent0/executor_train/verl/tests/workers/rollout/rollout_vllm/test_vllm_spmd.py +++ b/Agent0/executor_train/verl/tests/workers/rollout/rollout_vllm/test_vllm_spmd.py @@ -18,7 +18,11 @@ import torch from torch.distributed.fsdp import CPUOffload, MixedPrecision from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.distributed.fsdp.api import ShardedStateDictConfig, ShardingStrategy, StateDictType +from torch.distributed.fsdp.api import ( + ShardedStateDictConfig, + ShardingStrategy, + StateDictType, +) from transformers import AutoModelForCausalLM, AutoTokenizer from vllm import LLM, SamplingParams @@ -70,7 +74,9 @@ def are_lists_similar(a, b): @pytest.mark.skip("https://github.com/vllm-project/vllm/issues/16993") def test_vllm_spmd(): - assert torch.cuda.device_count() >= 2, "At least 2 
GPUs is required to run tp+dp tests." + assert ( + torch.cuda.device_count() >= 2 + ), "At least 2 GPUs is required to run tp+dp tests." local_rank, rank, world_size = initialize_global_process_group() # Initialize model and token @@ -80,9 +86,13 @@ def test_vllm_spmd(): from verl.utils.fs import copy_to_local local_model_path = copy_to_local(src=hdfs_path, cache_dir=local_cache_path) - tokenizer = AutoTokenizer.from_pretrained(local_model_path, padding_side="left", trust_remote_code=True) + tokenizer = AutoTokenizer.from_pretrained( + local_model_path, padding_side="left", trust_remote_code=True + ) - actor_model = AutoModelForCausalLM.from_pretrained(local_model_path, trust_remote_code=True) + actor_model = AutoModelForCausalLM.from_pretrained( + local_model_path, trust_remote_code=True + ) actor_model.to(torch.bfloat16) # fill rollout config @@ -98,8 +108,12 @@ def test_vllm_spmd(): input_ids = prompts["input_ids"] attention_mask = prompts["attention_mask"] - input_ids = pad_sequence_to_length(input_ids, max_prompt_length, tokenizer.pad_token_id, left_pad=True) - attention_mask = pad_sequence_to_length(attention_mask, max_prompt_length, 0, left_pad=True) + input_ids = pad_sequence_to_length( + input_ids, max_prompt_length, tokenizer.pad_token_id, left_pad=True + ) + attention_mask = pad_sequence_to_length( + attention_mask, max_prompt_length, 0, left_pad=True + ) print("start generation") input_ids = input_ids.cuda() @@ -108,16 +122,27 @@ def test_vllm_spmd(): temperature = 0 top_p = 1 kwargs = dict( - n=1, temperature=temperature, top_p=top_p, max_tokens=max_response_length, logprobs=1, ignore_eos=True + n=1, + temperature=temperature, + top_p=top_p, + max_tokens=max_response_length, + logprobs=1, + ignore_eos=True, ) tensor_parallel_size = 4 from torch.distributed.device_mesh import init_device_mesh - device_mesh = init_device_mesh("cuda", mesh_shape=(world_size,), mesh_dim_names=["fsdp"]) + device_mesh = init_device_mesh( + "cuda", mesh_shape=(world_size,), 
mesh_dim_names=["fsdp"] + ) - mixed_precision = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.float32) + mixed_precision = MixedPrecision( + param_dtype=torch.bfloat16, + reduce_dtype=torch.float32, + buffer_dtype=torch.float32, + ) fsdp_model = FSDP( actor_model, @@ -132,7 +157,9 @@ def test_vllm_spmd(): ) FSDP.set_state_dict_type( - fsdp_model, state_dict_type=StateDictType.SHARDED_STATE_DICT, state_dict_config=ShardedStateDictConfig() + fsdp_model, + state_dict_type=StateDictType.SHARDED_STATE_DICT, + state_dict_config=ShardedStateDictConfig(), ) state_dict = fsdp_model.state_dict() @@ -153,7 +180,9 @@ def test_vllm_spmd(): seed=1, ) - outputs = llm.generate(preencode_prompts, sampling_params=sampling_params, use_tqdm=False) + outputs = llm.generate( + preencode_prompts, sampling_params=sampling_params, use_tqdm=False + ) vllm_response_tokens = [] for output in outputs: generated_text = output.outputs[0].text @@ -162,10 +191,15 @@ def test_vllm_spmd(): world_size = torch.distributed.get_world_size() model = llm.llm_engine.model_executor.driver_worker.worker.model_runner.model model.load_weights( - ((name, param.full_tensor() if world_size != 1 else param) for name, param in state_dict.items()) + ( + (name, param.full_tensor() if world_size != 1 else param) + for name, param in state_dict.items() + ) ) - outputs = llm.generate(preencode_prompts, sampling_params=sampling_params, use_tqdm=False) + outputs = llm.generate( + preencode_prompts, sampling_params=sampling_params, use_tqdm=False + ) verl_vllm_response_tokens = [] for output in outputs: generated_text = output.outputs[0].text @@ -174,7 +208,9 @@ def test_vllm_spmd(): if torch.distributed.get_rank() == 0: print(f"vllm response: {vllm_response_tokens}") print(f"verl-vllm response: {verl_vllm_response_tokens}") - assert are_lists_similar(vllm_response_tokens, verl_vllm_response_tokens), "Strings differ more than 10%:\n" + assert are_lists_similar( + 
vllm_response_tokens, verl_vllm_response_tokens + ), "Strings differ more than 10%:\n" print("Check Pass") torch.distributed.destroy_process_group() diff --git a/Agent0/executor_train/verl/tests/workers/rollout/test_async_sglang_server.py b/Agent0/executor_train/verl/tests/workers/rollout/test_async_sglang_server.py index 0b4e914..3d3a8b6 100644 --- a/Agent0/executor_train/verl/tests/workers/rollout/test_async_sglang_server.py +++ b/Agent0/executor_train/verl/tests/workers/rollout/test_async_sglang_server.py @@ -21,7 +21,9 @@ @patch.dict( "sys.modules", { - "verl.workers.rollout.sglang_rollout.sglang_rollout": MagicMock(SGLangRollout=MagicMock()), + "verl.workers.rollout.sglang_rollout.sglang_rollout": MagicMock( + SGLangRollout=MagicMock() + ), }, ) class TestAsyncSglangServer: @@ -30,10 +32,19 @@ def server_config(self): return DictConfig({"rollout": {"tensor_model_parallel_size": 2}}) @pytest.mark.asyncio - @patch("verl.workers.rollout.sglang_rollout.async_sglang_server.ray.util.list_named_actors") - @patch("verl.workers.rollout.async_server.AsyncServerBase._start_fastapi_server", new_callable=AsyncMock) - @pytest.mark.filterwarnings("ignore:Ray state API is no longer experimental:DeprecationWarning") - async def test_init_engine(self, mock_start_fastapi_server, mock_list_actors, server_config): + @patch( + "verl.workers.rollout.sglang_rollout.async_sglang_server.ray.util.list_named_actors" + ) + @patch( + "verl.workers.rollout.async_server.AsyncServerBase._start_fastapi_server", + new_callable=AsyncMock, + ) + @pytest.mark.filterwarnings( + "ignore:Ray state API is no longer experimental:DeprecationWarning" + ) + async def test_init_engine( + self, mock_start_fastapi_server, mock_list_actors, server_config + ): mock_list_actors.return_value = [ {"name": "test_prefixWorkerDict_1:0", "namespace": "test"}, {"name": "test_prefixWorkerDict_1:1", "namespace": "test"}, @@ -44,7 +55,9 @@ async def test_init_engine(self, mock_start_fastapi_server, mock_list_actors, se 
{"name": "test_prefixWorkerDict_0:2", "namespace": "test"}, {"name": "test_prefixWorkerDict_0:3", "namespace": "test"}, ] - from verl.workers.rollout.sglang_rollout.async_sglang_server import AsyncSglangServer + from verl.workers.rollout.sglang_rollout.async_sglang_server import ( + AsyncSglangServer, + ) ActualClassToInstantiate = AsyncSglangServer if hasattr(AsyncSglangServer, "__ray_metadata__") and hasattr( diff --git a/Agent0/executor_train/verl/tests/workers/rollout/test_custom_completion_callback.py b/Agent0/executor_train/verl/tests/workers/rollout/test_custom_completion_callback.py index 495bce9..d3767b9 100644 --- a/Agent0/executor_train/verl/tests/workers/rollout/test_custom_completion_callback.py +++ b/Agent0/executor_train/verl/tests/workers/rollout/test_custom_completion_callback.py @@ -35,7 +35,10 @@ from verl.protocol import DataProto from verl.utils import hf_tokenizer from verl.utils.reward_score.sandbox_fusion.utils import _process_single_case -from verl.workers.rollout.chat_scheduler import ChatCompletionScheduler, ToolCompletionCallback +from verl.workers.rollout.chat_scheduler import ( + ChatCompletionScheduler, + ToolCompletionCallback, +) def _get_free_port(): @@ -63,13 +66,18 @@ async def code_execution(self, request: Request): code = request_json["code"] print(f"execute code:\n{code}") - _, temp_file = tempfile.mkstemp(suffix=".py", prefix="temp_code", dir=None, text=True) + _, temp_file = tempfile.mkstemp( + suffix=".py", prefix="temp_code", dir=None, text=True + ) with open(temp_file, "w") as f: f.write(code) try: process = await asyncio.create_subprocess_exec( - sys.executable, temp_file, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE + sys.executable, + temp_file, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, ) stdout, stderr = await process.communicate() @@ -97,14 +105,18 @@ async def lifespan(app: fastapi.FastAPI): self.server_ready.set() yield - print("FastAPI shutdown, maybe address already in 
use, exit process immediately.") + print( + "FastAPI shutdown, maybe address already in use, exit process immediately." + ) os._exit(-1) app = fastapi.FastAPI(lifespan=lifespan) app.router.add_api_route("/run_code", self.code_execution, methods=["POST"]) self.port = _get_free_port() - config = uvicorn.Config(app, host=["::", "0.0.0.0"], port=self.port, log_level="warning") + config = uvicorn.Config( + app, host=["::", "0.0.0.0"], port=self.port, log_level="warning" + ) server = uvicorn.Server(config) await server.serve() @@ -120,13 +132,17 @@ def __init__(self, config: DictConfig, scheduler: ChatCompletionScheduler): self.max_assistant_turns = 16 self.answer_pattern = re.compile(r"(.*?)", re.DOTALL) - self.code_pattern = re.compile(r"\s*```python(.*?)```\s*", re.DOTALL) + self.code_pattern = re.compile( + r"\s*```python(.*?)```\s*", re.DOTALL + ) self.sandbox_fusion_url = config.reward_model.sandbox_fusion.url self.default_timeout = 10 self.memory_limit_mb = config.reward_model.sandbox_fusion.memory_limit_mb # TODO: support asyncio executor - self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=max(32, os.cpu_count() * 5)) + self.executor = concurrent.futures.ThreadPoolExecutor( + max_workers=max(32, os.cpu_count() * 5) + ) async def sandbox_code_execution(self, code: str) -> dict[str, Any]: loop = asyncio.get_running_loop() @@ -153,7 +169,12 @@ def extra_body(self): } return extra - async def __call__(self, messages: list[dict[str, str]], completions: ChatCompletion, info: dict[str, Any]): + async def __call__( + self, + messages: list[dict[str, str]], + completions: ChatCompletion, + info: dict[str, Any], + ): role, content, finish_reason = ( completions.choices[0].message.role, completions.choices[0].message.content, @@ -164,24 +185,32 @@ async def __call__(self, messages: list[dict[str, str]], completions: ChatComple # STEP 0: check if we reach max turns if len(messages) >= self.max_assistant_turns: - 
print(f"[id={completions.id},turn={turn},finish_reason={finish_reason}] Reach max turns, done!") + print( + f"[id={completions.id},turn={turn},finish_reason={finish_reason}] Reach max turns, done!" + ) return # STEP 1: check if we reach max tokens if finish_reason == "length": - print(f"[id={completions.id},turn={turn},finish_reason={finish_reason}] Reach max tokens, done!") + print( + f"[id={completions.id},turn={turn},finish_reason={finish_reason}] Reach max tokens, done!" + ) return # STEP 2: check if we got answer matches = self.answer_pattern.findall(content) if matches: - print(f"[id={completions.id},turn={turn},finish_reason={finish_reason}] Got answer: {matches[0]}, done!") + print( + f"[id={completions.id},turn={turn},finish_reason={finish_reason}] Got answer: {matches[0]}, done!" + ) return # STEP 3: check if we got code block matches = self.code_pattern.findall(content) if not matches: - print(f"[id={completions.id},turn={turn},finish_reason={finish_reason}] No code block found, done!") + print( + f"[id={completions.id},turn={turn},finish_reason={finish_reason}] No code block found, done!" + ) return # STEP 4: execute code block in sandbox @@ -195,8 +224,12 @@ async def __call__(self, messages: list[dict[str, str]], completions: ChatComple return stdout, stderr = metadata["stdout"], metadata["stderr"] - messages.append({"role": "tool", "content": f"{stdout}{stderr}"}) - print(f"[id={completions.id},turn={turn},finish_reason={finish_reason}] Code block executed, continue...") + messages.append( + {"role": "tool", "content": f"{stdout}{stderr}"} + ) + print( + f"[id={completions.id},turn={turn},finish_reason={finish_reason}] Code block executed, continue..." 
+ ) # STEP 5: resubmit chat completions with code block output self.scheduler.submit_chat_completions( @@ -273,7 +306,14 @@ async def __call__(self, messages: list[dict[str, str]], completions: ChatComple non_tensor_batch={ "raw_prompt": np.array( [ - [{"role": "user", "content": user_prompt_template.replace("{question}", problem)}] + [ + { + "role": "user", + "content": user_prompt_template.replace( + "{question}", problem + ), + } + ] for problem in dataset["Problem"] ] ), @@ -292,14 +332,20 @@ async def __call__(self, messages: list[dict[str, str]], completions: ChatComple tokenizer = hf_tokenizer(config.actor_rollout_ref.model.path) responses = result.batch["responses"] response_mask = result.batch["response_mask"] - assert responses.size() == response_mask.size(), f"{responses.size()} != {response_mask.size()}" + assert ( + responses.size() == response_mask.size() + ), f"{responses.size()} != {response_mask.size()}" # Decode responses with response_mask for i in range(len(responses)): valid_tokens = responses[i][response_mask[i].bool()] response_str = tokenizer.decode(valid_tokens) - assert "" not in response_str, f"found in response: {response_str}" - assert "" not in response_str, f"found in response: {response_str}" + assert ( + "" not in response_str + ), f"found in response: {response_str}" + assert ( + "" not in response_str + ), f"found in response: {response_str}" print(f"response: {response_str}") print("Test passed!") diff --git a/Agent0/executor_train/verl/tests/workers/rollout/test_hf_rollout.py b/Agent0/executor_train/verl/tests/workers/rollout/test_hf_rollout.py index 3eb6f4b..fc1b3db 100644 --- a/Agent0/executor_train/verl/tests/workers/rollout/test_hf_rollout.py +++ b/Agent0/executor_train/verl/tests/workers/rollout/test_hf_rollout.py @@ -18,7 +18,11 @@ from omegaconf import OmegaConf from torch.distributed.fsdp import CPUOffload, MixedPrecision from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.distributed.fsdp.api 
import ShardedStateDictConfig, ShardingStrategy, StateDictType +from torch.distributed.fsdp.api import ( + ShardedStateDictConfig, + ShardingStrategy, + StateDictType, +) from transformers import AutoModelForCausalLM, AutoTokenizer from verl import DataProto @@ -52,10 +56,17 @@ def prepare_input_dataproto(tokenizer, config, validate): [{"role": "user", "content": "What's your name"}], ] formatted_prompts = [ - tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True) + tokenizer.apply_chat_template( + conversation, tokenize=False, add_generation_prompt=True + ) for conversation in preencode_prompts ] - prompts = tokenizer(formatted_prompts, return_tensors="pt", padding="max_length", max_length=config.prompt_length) + prompts = tokenizer( + formatted_prompts, + return_tensors="pt", + padding="max_length", + max_length=config.prompt_length, + ) input_dataproto = DataProto.from_dict( { "input_ids": prompts["input_ids"], @@ -75,9 +86,15 @@ def prepare_input_dataproto(tokenizer, config, validate): def prepare_fsdp_model(model, world_size): from torch.distributed.device_mesh import init_device_mesh - device_mesh = init_device_mesh("cuda", mesh_shape=(world_size,), mesh_dim_names=["fsdp"]) + device_mesh = init_device_mesh( + "cuda", mesh_shape=(world_size,), mesh_dim_names=["fsdp"] + ) - mixed_precision = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.float32) + mixed_precision = MixedPrecision( + param_dtype=torch.bfloat16, + reduce_dtype=torch.float32, + buffer_dtype=torch.float32, + ) fsdp_model = FSDP( model, @@ -92,7 +109,9 @@ def prepare_fsdp_model(model, world_size): ) FSDP.set_state_dict_type( - fsdp_model, state_dict_type=StateDictType.SHARDED_STATE_DICT, state_dict_config=ShardedStateDictConfig() + fsdp_model, + state_dict_type=StateDictType.SHARDED_STATE_DICT, + state_dict_config=ShardedStateDictConfig(), ) return fsdp_model @@ -101,7 +120,9 @@ def test_hf_rollout(n: int = 1, do_sample: 
bool = True, validate: bool = False): config = OmegaConf.create(BASE_HF_ROLLOUT_CONFIG) config.update({"n": n, "do_sample": do_sample}) - assert torch.cuda.device_count() >= 2, "At least 2 GPUs is required to run tp+dp tests." + assert ( + torch.cuda.device_count() >= 2 + ), "At least 2 GPUs is required to run tp+dp tests." local_rank, rank, world_size = initialize_global_process_group() # Initialize model and tokenizer @@ -109,17 +130,23 @@ def test_hf_rollout(n: int = 1, do_sample: bool = True, validate: bool = False): local_cache_path = os.path.expanduser(local_cache_path) hdfs_path = "Qwen/Qwen2-7B-Instruct" local_model_path = copy_to_local(src=hdfs_path, cache_dir=local_cache_path) - tokenizer = AutoTokenizer.from_pretrained(local_model_path, padding_side="left", trust_remote_code=True) + tokenizer = AutoTokenizer.from_pretrained( + local_model_path, padding_side="left", trust_remote_code=True + ) tokenizer.pad_token = tokenizer.eos_token # Initialize FSDP model - actor_model = AutoModelForCausalLM.from_pretrained(local_model_path, trust_remote_code=True) + actor_model = AutoModelForCausalLM.from_pretrained( + local_model_path, trust_remote_code=True + ) actor_model.to(torch.bfloat16) fsdp_model = prepare_fsdp_model(actor_model, world_size) # Initialize HFRollout and start generate hf_rollout = HFRollout(fsdp_model, OmegaConf.create(config)) - input = prepare_input_dataproto(tokenizer, config, validate).to(torch.cuda.current_device()) + input = prepare_input_dataproto(tokenizer, config, validate).to( + torch.cuda.current_device() + ) outputs = hf_rollout.generate_sequences(input) # check generated batch size is expected @@ -147,16 +174,22 @@ def test_hf_rollout(n: int = 1, do_sample: bool = True, validate: bool = False): # check response attention mask is expected response_attention = attention_mask[prompt_length:] - eos_positions = (outputs.batch["responses"][i] == tokenizer.pad_token_id).nonzero(as_tuple=True)[0] + eos_positions = ( + 
outputs.batch["responses"][i] == tokenizer.pad_token_id + ).nonzero(as_tuple=True)[0] if len(eos_positions) > 0: first_eos_pos = eos_positions[0].item() - assert response_attention[: first_eos_pos + 1].all(), "Response attention mask should be 1 until EOS" + assert response_attention[ + : first_eos_pos + 1 + ].all(), "Response attention mask should be 1 until EOS" if first_eos_pos + 1 < response_length: - assert not response_attention[first_eos_pos + 1 :].any(), ( - "Response attention mask should be 0 after EOS" - ) + assert not response_attention[ + first_eos_pos + 1 : + ].any(), "Response attention mask should be 0 after EOS" else: - assert response_attention.all(), "Response attention mask should be all 1 if no EOS token" + assert ( + response_attention.all() + ), "Response attention mask should be all 1 if no EOS token" # check response position ids is expected prompt_positions = position_ids[:prompt_length] diff --git a/Agent0/executor_train/verl/tests/workers/rollout/test_sglang_async_rollout_mcp_tools.py b/Agent0/executor_train/verl/tests/workers/rollout/test_sglang_async_rollout_mcp_tools.py index 387de16..256ecc6 100644 --- a/Agent0/executor_train/verl/tests/workers/rollout/test_sglang_async_rollout_mcp_tools.py +++ b/Agent0/executor_train/verl/tests/workers/rollout/test_sglang_async_rollout_mcp_tools.py @@ -29,7 +29,11 @@ from verl.protocol import DataProto from verl.tools.mcp_search_tool import MCPSearchTool from verl.tools.utils.mcp_clients.McpClientManager import MCPClientManager -from verl.workers.rollout.schemas import AsyncRolloutRequest, AsyncRolloutRequestStateEnum, Message +from verl.workers.rollout.schemas import ( + AsyncRolloutRequest, + AsyncRolloutRequestStateEnum, + Message, +) from verl.workers.rollout.sglang_rollout.sglang_rollout import SGLangRollout DEFAULT_USER_CONTENT_PREFIX = ( @@ -100,10 +104,15 @@ def get_search_messages(): } # Mock search tool responses - tool_return_0_msg = {"role": "tool", "content": [{"type": "text", "text": 
"Today's weather in Beijing is sunny."}]} + tool_return_0_msg = { + "role": "tool", + "content": [{"type": "text", "text": "Today's weather in Beijing is sunny."}], + } tool_return_1_msg = { "role": "tool", - "content": [{"type": "text", "text": "Tomorrow's weather in Beijing is cloudy."}], + "content": [ + {"type": "text", "text": "Tomorrow's weather in Beijing is cloudy."} + ], } user_prompts = [user_prompt] @@ -133,11 +142,15 @@ def search_data(self, qwen_tokenizer): user_prompt, expect_turn_array, tool_return_array = get_search_messages() prompts = [[message] for message in user_prompt] preencode_turn_array = [ - qwen_tokenizer.apply_chat_template([turn], tokenize=False, add_generation_prompt=False) + qwen_tokenizer.apply_chat_template( + [turn], tokenize=False, add_generation_prompt=False + ) for turn in expect_turn_array ] preencode_tool_return_array = [ - qwen_tokenizer.apply_chat_template([turn], tokenize=False, add_generation_prompt=True) + qwen_tokenizer.apply_chat_template( + [turn], tokenize=False, add_generation_prompt=True + ) for turn in tool_return_array ] return prompts, preencode_turn_array, preencode_tool_return_array @@ -150,7 +163,11 @@ def search_rollout_config(self): tensor_parallel_size = 1 tool_path = "./resource/tool_configs/mcp_tool_config" rollout_config = get_rollout_config( - max_response_length, max_prompt_length, dtype, tensor_parallel_size, tool_path + max_response_length, + max_prompt_length, + dtype, + tensor_parallel_size, + tool_path, ) return rollout_config @@ -158,10 +175,14 @@ def search_rollout_config(self): def search_data_proto(self, search_data, qwen_tokenizer): preencode_prompts, _, _ = search_data prompts = [ - qwen_tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True) + qwen_tokenizer.apply_chat_template( + message, tokenize=False, add_generation_prompt=True + ) for message in preencode_prompts ] - input_ids, attention_mask, position_ids = prepare_inputs(qwen_tokenizer, prompts, 1000) + 
input_ids, attention_mask, position_ids = prepare_inputs( + qwen_tokenizer, prompts, 1000 + ) prompt_dict = TensorDict( { "input_ids": input_ids, @@ -176,7 +197,9 @@ def search_data_proto(self, search_data, qwen_tokenizer): [ { "tavily_search_tool": { - "create_kwargs": {"ground_truth": "Today is sunny and tomorrow will be cloudy in Beijing."}, + "create_kwargs": { + "ground_truth": "Today is sunny and tomorrow will be cloudy in Beijing." + }, }, } ], @@ -184,7 +207,12 @@ def search_data_proto(self, search_data, qwen_tokenizer): ) index = np.array([0], dtype=object) prompts = DataProto( - batch=prompt_dict, non_tensor_batch={"raw_prompt": messages, "tools_kwargs": tools_kwargs, "index": index} + batch=prompt_dict, + non_tensor_batch={ + "raw_prompt": messages, + "tools_kwargs": tools_kwargs, + "index": index, + }, ) return prompts @@ -263,7 +291,9 @@ def mock_rollout(self, search_rollout_config, qwen_tokenizer, qwen_model_config) } ] with ( - patch.object(MCPClientManager, "fetch_tool_schemas", return_value=tool_schema), + patch.object( + MCPClientManager, "fetch_tool_schemas", return_value=tool_schema + ), patch.object(SGLangRollout, "_init_distributed_env", return_value=None), patch.object(SGLangRollout, "_init_inference_engine", return_value=None), patch.object(SGLangRollout, "_init_sampling_params", return_value=None), @@ -293,14 +323,18 @@ def test_tools_registration(self, mock_rollout): assert mock_rollout._tool_call_parser_type == "qwen25" def test_rollout_req_creation(self, mock_rollout, search_data_proto): - req_list = mock_rollout._preprocess_prompt_to_async_rollout_requests(search_data_proto, n=1) + req_list = mock_rollout._preprocess_prompt_to_async_rollout_requests( + search_data_proto, n=1 + ) assert len(req_list) == 1 assert req_list[0].state == AsyncRolloutRequestStateEnum.PENDING assert len(req_list[0].tool_schemas) == 1 def test_over_size_case(self, mock_rollout, search_data_proto, search_data): mock_rollout.config.multi_turn.max_assistant_turns = 
1 - req = mock_rollout._preprocess_prompt_to_async_rollout_requests(search_data_proto, n=1)[0] + req = mock_rollout._preprocess_prompt_to_async_rollout_requests( + search_data_proto, n=1 + )[0] req = MagicMock(wraps=req, spec=AsyncRolloutRequest) req.finalize = MagicMock() req_list = [req] @@ -327,7 +361,10 @@ def test_over_size_case(self, mock_rollout, search_data_proto, search_data): loop = asyncio.get_event_loop() output_req_list = loop.run_until_complete( asyncio.gather( - *[mock_rollout._async_rollout_a_request(req, True, False) for req in req_list], + *[ + mock_rollout._async_rollout_a_request(req, True, False) + for req in req_list + ], ) ) assert len(output_req_list) == 1 @@ -343,13 +380,19 @@ def test_over_size_case(self, mock_rollout, search_data_proto, search_data): ) @patch.object(MCPSearchTool, "execute", new_callable=AsyncMock) - def test_tool_call_basic_case(self, mock_execute, mock_rollout, search_data_proto, search_data): + def test_tool_call_basic_case( + self, mock_execute, mock_rollout, search_data_proto, search_data + ): _, expect_turn_array, tool_return_array = search_data # Mock search tool execution to return predefined responses - mock_execute.side_effect = [(msg, 0.0, {"status": "success"}) for msg in tool_return_array] + mock_execute.side_effect = [ + (msg, 0.0, {"status": "success"}) for msg in tool_return_array + ] mock_rollout.config.multi_turn.max_assistant_turns = 10 - req = mock_rollout._preprocess_prompt_to_async_rollout_requests(search_data_proto, n=1)[0] + req = mock_rollout._preprocess_prompt_to_async_rollout_requests( + search_data_proto, n=1 + )[0] req = MagicMock(wraps=req, spec=AsyncRolloutRequest) req.finalize = MagicMock() req_list = [req] @@ -362,7 +405,13 @@ def test_tool_call_basic_case(self, mock_execute, mock_rollout, search_data_prot "text": turn, "meta_info": { "id": "d1188d81cba840359df5b352b344bc8e", - "finish_reason": {"type": "tool_calls" if idx < len(expect_turn_array) - 1 else "stop"}, + "finish_reason": { + 
"type": ( + "tool_calls" + if idx < len(expect_turn_array) - 1 + else "stop" + ) + }, "prompt_tokens": len(turn), "completion_tokens": 100, "cached_tokens": 0, @@ -379,7 +428,12 @@ def test_tool_call_basic_case(self, mock_execute, mock_rollout, search_data_prot loop = asyncio.get_event_loop() output_req_list = loop.run_until_complete( - asyncio.gather(*[mock_rollout._async_rollout_a_request(req, True, False) for req in req_list]) + asyncio.gather( + *[ + mock_rollout._async_rollout_a_request(req, True, False) + for req in req_list + ] + ) ) # Verify conversation completed successfully with proper tool usage @@ -398,7 +452,9 @@ def test_tool_call_basic_case(self, mock_execute, mock_rollout, search_data_prot assert search_counter == 2 @patch.object(MCPSearchTool, "execute", new_callable=AsyncMock) - def test_tool_call_batch_case(self, mock_execute, mock_rollout, search_data_proto, search_data): + def test_tool_call_batch_case( + self, mock_execute, mock_rollout, search_data_proto, search_data + ): _, expect_turn_array, tool_return_array = search_data # Mock tool execution for large batch (100 requests * 2 calls each) mock_execute.side_effect = [ @@ -407,7 +463,9 @@ def test_tool_call_batch_case(self, mock_execute, mock_rollout, search_data_prot ] * 100 mock_rollout.config.multi_turn.max_assistant_turns = 10 - base_req = mock_rollout._preprocess_prompt_to_async_rollout_requests(search_data_proto, n=1)[0] + base_req = mock_rollout._preprocess_prompt_to_async_rollout_requests( + search_data_proto, n=1 + )[0] req_nums = 100 req_list = [] @@ -421,13 +479,21 @@ def test_tool_call_batch_case(self, mock_execute, mock_rollout, search_data_prot req_list.append(MagicMock(wraps=tmp_req, spec=AsyncRolloutRequest)) futures = [asyncio.Future() for _ in expect_turn_array] - for idx, (fut, turn) in enumerate(zip(futures, expect_turn_array, strict=True)): + for idx, (fut, turn) in enumerate( + zip(futures, expect_turn_array, strict=True) + ): fut.set_result( { "text": turn, 
"meta_info": { "id": "dummy", - "finish_reason": {"type": "tool_calls" if idx < len(expect_turn_array) - 1 else "stop"}, + "finish_reason": { + "type": ( + "tool_calls" + if idx < len(expect_turn_array) - 1 + else "stop" + ) + }, "prompt_tokens": len(turn), "completion_tokens": 100, }, @@ -436,16 +502,27 @@ def test_tool_call_batch_case(self, mock_execute, mock_rollout, search_data_prot req_turns_map[i] = futures req_turns_counter[i] = 0 - async def hacked_handle_engine_call(self, _req: AsyncRolloutRequest, *_args, **_kwargs): - fut = req_turns_map[_req.batch_data_id][req_turns_counter[_req.batch_data_id]] + async def hacked_handle_engine_call( + self, _req: AsyncRolloutRequest, *_args, **_kwargs + ): + fut = req_turns_map[_req.batch_data_id][ + req_turns_counter[_req.batch_data_id] + ] req_turns_counter[_req.batch_data_id] += 1 return await fut - with patch.object(SGLangRollout, "_handle_engine_call", new=hacked_handle_engine_call): + with patch.object( + SGLangRollout, "_handle_engine_call", new=hacked_handle_engine_call + ): mock_rollout._tp_rank = 0 loop = asyncio.get_event_loop() output_req_list = loop.run_until_complete( - asyncio.gather(*[mock_rollout._async_rollout_a_request(r, True, False) for r in req_list]) + asyncio.gather( + *[ + mock_rollout._async_rollout_a_request(r, True, False) + for r in req_list + ] + ) ) # Verify all requests completed successfully diff --git a/Agent0/executor_train/verl/tests/workers/rollout/test_sglang_async_rollout_multimodal_delta.py b/Agent0/executor_train/verl/tests/workers/rollout/test_sglang_async_rollout_multimodal_delta.py index 47fefca..32607e5 100644 --- a/Agent0/executor_train/verl/tests/workers/rollout/test_sglang_async_rollout_multimodal_delta.py +++ b/Agent0/executor_train/verl/tests/workers/rollout/test_sglang_async_rollout_multimodal_delta.py @@ -25,7 +25,9 @@ ) -def _test_add_tool_response_messages_image_delta(processor, image_list, description_list, resize_image=False): +def 
_test_add_tool_response_messages_image_delta( + processor, image_list, description_list, resize_image=False +): assert len(image_list) == len(description_list) # Get the smallest dimensions across all images processed_images = [] @@ -45,9 +47,7 @@ def _test_add_tool_response_messages_image_delta(processor, image_list, descript processed_images = processed_images_resized # Initial message history - system_prompt = ( - "You will be provided with an image. Describe this image and then generate a new image for the next round" - ) + system_prompt = "You will be provided with an image. Describe this image and then generate a new image for the next round" messages = [ { "role": "system", @@ -109,7 +109,10 @@ def _test_add_tool_response_messages_image_delta(processor, image_list, descript _ = req.get_generation_prompt_ids(processor) req.add_assistant_message(processor, content=description_list[idx - 1]) before_tool_call_len = req.input_ids.shape[-1] - req.add_tool_response_messages(processor, [{"image": [img], "text": "Here is the new image you requested: "}]) + req.add_tool_response_messages( + processor, + [{"image": [img], "text": "Here is the new image you requested: "}], + ) after_tool_call_len = req.input_ids.shape[-1] if prev_generated_len == 0: prev_generated_len = after_tool_call_len - before_tool_call_len @@ -122,7 +125,9 @@ def _test_add_tool_response_messages_image_delta(processor, image_list, descript req.add_assistant_message(processor, content=description_list[-1]) messages = [msg.model_dump() for msg in req.messages] - tools = [tool.model_dump() for tool in req.tool_schemas] if req.tool_schemas else None + tools = ( + [tool.model_dump() for tool in req.tool_schemas] if req.tool_schemas else None + ) full_prompt_info = req._handle_apply_chat_template( processor, messages, @@ -146,16 +151,21 @@ def _test_add_tool_response_messages_image_delta(processor, image_list, descript @pytest.mark.skipif( - hf_processor("Qwen/Qwen2.5-VL-3B-Instruct") is None, 
reason="Processor not available for Qwen/Qwen2.5-VL-B-Instruct" + hf_processor("Qwen/Qwen2.5-VL-3B-Instruct") is None, + reason="Processor not available for Qwen/Qwen2.5-VL-B-Instruct", ) def test_add_tool_response_messages_image_delta(): processor = hf_processor("Qwen/Qwen2.5-VL-3B-Instruct") # From Qwen2.5-VL-3B-Instruct HF example - img_1_url = {"image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg"} + img_1_url = { + "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg" + } img_1_description = "A woman sits on the beach at sunset, smiling as she shares a high five with her large dog." # GitHub Logo - img_2_url = {"image": "https://github.githubassets.com/assets/GitHub-Mark-ea2971cee799.png"} + img_2_url = { + "image": "https://github.githubassets.com/assets/GitHub-Mark-ea2971cee799.png" + } img_2_description = "A GitHub Logo image" # Octocat img_3_url = {"image": "https://octodex.github.com/images/orderedlistocat.png"} @@ -163,20 +173,27 @@ def test_add_tool_response_messages_image_delta(): image_list = [img_1_url, img_2_url, img_3_url] description_list = [img_1_description, img_2_description, img_3_description] - _test_add_tool_response_messages_image_delta(processor, image_list, description_list, resize_image=False) + _test_add_tool_response_messages_image_delta( + processor, image_list, description_list, resize_image=False + ) @pytest.mark.skipif( - hf_processor("Qwen/Qwen2.5-VL-3B-Instruct") is None, reason="Processor not available for Qwen/Qwen2.5-VL-B-Instruct" + hf_processor("Qwen/Qwen2.5-VL-3B-Instruct") is None, + reason="Processor not available for Qwen/Qwen2.5-VL-B-Instruct", ) def test_add_tool_response_messages_image_delta_resize_image(): processor = hf_processor("Qwen/Qwen2.5-VL-3B-Instruct") # From Qwen2.5-VL-3B-Instruct HF example - img_1_url = {"image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg"} + img_1_url = { + "image": 
"https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg" + } img_1_description = "A woman sits on the beach at sunset, smiling as she shares a high five with her large dog." # GitHub Logo - img_2_url = {"image": "https://github.githubassets.com/assets/GitHub-Mark-ea2971cee799.png"} + img_2_url = { + "image": "https://github.githubassets.com/assets/GitHub-Mark-ea2971cee799.png" + } img_2_description = "A GitHub Logo image" # Octocat img_3_url = {"image": "https://octodex.github.com/images/orderedlistocat.png"} @@ -184,4 +201,6 @@ def test_add_tool_response_messages_image_delta_resize_image(): image_list = [img_1_url, img_2_url, img_3_url] description_list = [img_1_description, img_2_description, img_3_description] - _test_add_tool_response_messages_image_delta(processor, image_list, description_list, resize_image=True) + _test_add_tool_response_messages_image_delta( + processor, image_list, description_list, resize_image=True + ) diff --git a/Agent0/executor_train/verl/tests/workers/rollout/test_sglang_async_rollout_search_tools.py b/Agent0/executor_train/verl/tests/workers/rollout/test_sglang_async_rollout_search_tools.py index 2400d5c..590e120 100644 --- a/Agent0/executor_train/verl/tests/workers/rollout/test_sglang_async_rollout_search_tools.py +++ b/Agent0/executor_train/verl/tests/workers/rollout/test_sglang_async_rollout_search_tools.py @@ -33,7 +33,11 @@ OpenAIFunctionToolSchema, ) from verl.tools.search_tool import SearchTool -from verl.workers.rollout.schemas import AsyncRolloutRequest, AsyncRolloutRequestStateEnum, Message +from verl.workers.rollout.schemas import ( + AsyncRolloutRequest, + AsyncRolloutRequestStateEnum, + Message, +) from verl.workers.rollout.sglang_rollout.sglang_rollout import SGLangRollout DEFAULT_USER_CONTENT_PREFIX = ( @@ -58,14 +62,28 @@ def get_search_messages(): expect_turn_0_msg = { "role": "assistant", "content": "Let me search the web.", - "tool_calls": [{"type": "function", "function": {"name": "search", 
"arguments": {"query": "today's weather"}}}], + "tool_calls": [ + { + "type": "function", + "function": { + "name": "search", + "arguments": {"query": "today's weather"}, + }, + } + ], } expect_turn_1_msg = { "role": "assistant", "content": "Let me search again.", "tool_calls": [ - {"type": "function", "function": {"name": "search", "arguments": {"query": "tomorrow's weather"}}} + { + "type": "function", + "function": { + "name": "search", + "arguments": {"query": "tomorrow's weather"}, + }, + } ], } @@ -75,8 +93,14 @@ def get_search_messages(): } # Mock search tool responses - tool_return_0_msg = {"role": "tool", "content": "Today's weather in Beijing is sunny."} - tool_return_1_msg = {"role": "tool", "content": "Tomorrow's weather in Beijing is cloudy."} + tool_return_0_msg = { + "role": "tool", + "content": "Today's weather in Beijing is sunny.", + } + tool_return_1_msg = { + "role": "tool", + "content": "Tomorrow's weather in Beijing is cloudy.", + } user_prompts = [user_prompt] expect_turn_array = [expect_turn_0_msg, expect_turn_1_msg, expect_turn_2_msg] @@ -105,11 +129,15 @@ def search_data(self, qwen_tokenizer): user_prompt, expect_turn_array, tool_return_array = get_search_messages() prompts = [[message] for message in user_prompt] preencode_turn_array = [ - qwen_tokenizer.apply_chat_template([turn], tokenize=False, add_generation_prompt=False) + qwen_tokenizer.apply_chat_template( + [turn], tokenize=False, add_generation_prompt=False + ) for turn in expect_turn_array ] preencode_tool_return_array = [ - qwen_tokenizer.apply_chat_template([turn], tokenize=False, add_generation_prompt=True) + qwen_tokenizer.apply_chat_template( + [turn], tokenize=False, add_generation_prompt=True + ) for turn in tool_return_array ] return prompts, preencode_turn_array, preencode_tool_return_array @@ -122,7 +150,11 @@ def search_rollout_config(self): tensor_parallel_size = 1 tool_path = "./resource/tool_configs/search_tool_config" rollout_config = get_rollout_config( - 
max_response_length, max_prompt_length, dtype, tensor_parallel_size, tool_path + max_response_length, + max_prompt_length, + dtype, + tensor_parallel_size, + tool_path, ) return rollout_config @@ -130,10 +162,14 @@ def search_rollout_config(self): def search_data_proto(self, search_data, qwen_tokenizer): preencode_prompts, _, _ = search_data prompts = [ - qwen_tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True) + qwen_tokenizer.apply_chat_template( + message, tokenize=False, add_generation_prompt=True + ) for message in preencode_prompts ] - input_ids, attention_mask, position_ids = prepare_inputs(qwen_tokenizer, prompts, 1000) + input_ids, attention_mask, position_ids = prepare_inputs( + qwen_tokenizer, prompts, 1000 + ) prompt_dict = TensorDict( { "input_ids": input_ids, @@ -159,7 +195,12 @@ def search_data_proto(self, search_data, qwen_tokenizer): ) index = np.array([0], dtype=object) prompts = DataProto( - batch=prompt_dict, non_tensor_batch={"raw_prompt": messages, "tools_kwargs": tools_kwargs, "index": index} + batch=prompt_dict, + non_tensor_batch={ + "raw_prompt": messages, + "tools_kwargs": tools_kwargs, + "index": index, + }, ) return prompts @@ -190,7 +231,13 @@ def mock_rollout(self, search_rollout_config, qwen_tokenizer, qwen_model_config) @patch.object(SGLangRollout, "_init_inference_engine", return_value=None) @patch.object(SGLangRollout, "_init_sampling_params", return_value=None) def test_tools_registration( - self, mock_env, mock_engine, mock_sampling, search_rollout_config, qwen_tokenizer, qwen_model_config + self, + mock_env, + mock_engine, + mock_sampling, + search_rollout_config, + qwen_tokenizer, + qwen_model_config, ): rollout = SGLangRollout( actor_module="", @@ -225,7 +272,9 @@ def test_rollout_req_creation( processing_class=qwen_tokenizer, model_hf_config=qwen_model_config, ) - req_list = rollout._preprocess_prompt_to_async_rollout_requests(search_data_proto, n=1) + req_list = 
rollout._preprocess_prompt_to_async_rollout_requests( + search_data_proto, n=1 + ) assert len(req_list) == 1 assert req_list[0].state == AsyncRolloutRequestStateEnum.PENDING assert len(req_list[0].tool_schemas) == 1 @@ -253,7 +302,9 @@ def test_rollout_req_creation( def test_over_size_case(self, mock_rollout, search_data_proto, search_data): mock_rollout.config.multi_turn.max_assistant_turns = 1 - req = mock_rollout._preprocess_prompt_to_async_rollout_requests(search_data_proto, n=1)[0] + req = mock_rollout._preprocess_prompt_to_async_rollout_requests( + search_data_proto, n=1 + )[0] req = MagicMock(wraps=req, spec=AsyncRolloutRequest) req.finalize = MagicMock() req_list = [req] @@ -279,7 +330,10 @@ def test_over_size_case(self, mock_rollout, search_data_proto, search_data): loop = asyncio.get_event_loop() output_req_list = loop.run_until_complete( asyncio.gather( - *[mock_rollout._async_rollout_a_request(req, True, False) for req in req_list], + *[ + mock_rollout._async_rollout_a_request(req, True, False) + for req in req_list + ], ) ) assert len(output_req_list) == 1 @@ -294,16 +348,22 @@ def test_over_size_case(self, mock_rollout, search_data_proto, search_data): ) @patch.object(SearchTool, "execute", new_callable=AsyncMock) - def test_tool_call_basic_case(self, mock_execute, mock_rollout, search_data_proto, search_data): + def test_tool_call_basic_case( + self, mock_execute, mock_rollout, search_data_proto, search_data + ): _, expect_turn_array, tool_return_array = search_data # Mock search tool execution to return predefined responses - mock_execute.side_effect = [(msg, 0.0, {"status": "success"}) for msg in tool_return_array] + mock_execute.side_effect = [ + (msg, 0.0, {"status": "success"}) for msg in tool_return_array + ] mock_rollout.config.multi_turn.max_assistant_turns = 10 mock_rollout._tool_map["search"].retrieval_service_url = "mock://dummy" - req = mock_rollout._preprocess_prompt_to_async_rollout_requests(search_data_proto, n=1)[0] + req = 
mock_rollout._preprocess_prompt_to_async_rollout_requests( + search_data_proto, n=1 + )[0] req = MagicMock(wraps=req, spec=AsyncRolloutRequest) req.finalize = MagicMock() req_list = [req] @@ -316,7 +376,13 @@ def test_tool_call_basic_case(self, mock_execute, mock_rollout, search_data_prot "text": turn, "meta_info": { "id": "d1188d81cba840359df5b352b344bc8e", - "finish_reason": {"type": "tool_calls" if idx < len(expect_turn_array) - 1 else "stop"}, + "finish_reason": { + "type": ( + "tool_calls" + if idx < len(expect_turn_array) - 1 + else "stop" + ) + }, "prompt_tokens": len(turn), "completion_tokens": 100, "cached_tokens": 0, @@ -333,7 +399,12 @@ def test_tool_call_basic_case(self, mock_execute, mock_rollout, search_data_prot loop = asyncio.get_event_loop() output_req_list = loop.run_until_complete( - asyncio.gather(*[mock_rollout._async_rollout_a_request(req, True, False) for req in req_list]) + asyncio.gather( + *[ + mock_rollout._async_rollout_a_request(req, True, False) + for req in req_list + ] + ) ) # Verify conversation completed successfully with proper tool usage @@ -352,7 +423,9 @@ def test_tool_call_basic_case(self, mock_execute, mock_rollout, search_data_prot assert search_counter == 2 @patch.object(SearchTool, "execute", new_callable=AsyncMock) - def test_tool_call_batch_case(self, mock_execute, mock_rollout, search_data_proto, search_data): + def test_tool_call_batch_case( + self, mock_execute, mock_rollout, search_data_proto, search_data + ): _, expect_turn_array, tool_return_array = search_data # Mock tool execution for large batch (100 requests * 2 calls each) @@ -364,7 +437,9 @@ def test_tool_call_batch_case(self, mock_execute, mock_rollout, search_data_prot mock_rollout.config.multi_turn.max_assistant_turns = 10 mock_rollout._tool_map["search"].retrieval_service_url = "mock://dummy" - base_req = mock_rollout._preprocess_prompt_to_async_rollout_requests(search_data_proto, n=1)[0] + base_req = 
mock_rollout._preprocess_prompt_to_async_rollout_requests( + search_data_proto, n=1 + )[0] req_nums = 100 req_list = [] @@ -378,13 +453,21 @@ def test_tool_call_batch_case(self, mock_execute, mock_rollout, search_data_prot req_list.append(MagicMock(wraps=tmp_req, spec=AsyncRolloutRequest)) futures = [asyncio.Future() for _ in expect_turn_array] - for idx, (fut, turn) in enumerate(zip(futures, expect_turn_array, strict=True)): + for idx, (fut, turn) in enumerate( + zip(futures, expect_turn_array, strict=True) + ): fut.set_result( { "text": turn, "meta_info": { "id": "dummy", - "finish_reason": {"type": "tool_calls" if idx < len(expect_turn_array) - 1 else "stop"}, + "finish_reason": { + "type": ( + "tool_calls" + if idx < len(expect_turn_array) - 1 + else "stop" + ) + }, "prompt_tokens": len(turn), "completion_tokens": 100, }, @@ -393,16 +476,27 @@ def test_tool_call_batch_case(self, mock_execute, mock_rollout, search_data_prot req_turns_map[i] = futures req_turns_counter[i] = 0 - async def hacked_handle_engine_call(self, _req: AsyncRolloutRequest, *_args, **_kwargs): - fut = req_turns_map[_req.batch_data_id][req_turns_counter[_req.batch_data_id]] + async def hacked_handle_engine_call( + self, _req: AsyncRolloutRequest, *_args, **_kwargs + ): + fut = req_turns_map[_req.batch_data_id][ + req_turns_counter[_req.batch_data_id] + ] req_turns_counter[_req.batch_data_id] += 1 return await fut - with patch.object(SGLangRollout, "_handle_engine_call", new=hacked_handle_engine_call): + with patch.object( + SGLangRollout, "_handle_engine_call", new=hacked_handle_engine_call + ): mock_rollout._tp_rank = 0 loop = asyncio.get_event_loop() output_req_list = loop.run_until_complete( - asyncio.gather(*[mock_rollout._async_rollout_a_request(r, True, False) for r in req_list]) + asyncio.gather( + *[ + mock_rollout._async_rollout_a_request(r, True, False) + for r in req_list + ] + ) ) # Verify all requests completed successfully diff --git 
a/Agent0/executor_train/verl/tests/workers/rollout/test_sglang_async_rollout_sf_tools.py b/Agent0/executor_train/verl/tests/workers/rollout/test_sglang_async_rollout_sf_tools.py index 3f30929..4e7b227 100644 --- a/Agent0/executor_train/verl/tests/workers/rollout/test_sglang_async_rollout_sf_tools.py +++ b/Agent0/executor_train/verl/tests/workers/rollout/test_sglang_async_rollout_sf_tools.py @@ -38,7 +38,11 @@ OpenAIFunctionSchema, OpenAIFunctionToolSchema, ) -from verl.workers.rollout.schemas import AsyncRolloutRequest, AsyncRolloutRequestStateEnum, Message +from verl.workers.rollout.schemas import ( + AsyncRolloutRequest, + AsyncRolloutRequestStateEnum, + Message, +) from verl.workers.rollout.sglang_rollout.sglang_rollout import SGLangRollout sandbox_url = "" @@ -163,14 +167,20 @@ def qwen_model_config(self): @pytest.fixture def sandbox_fusion_data(self, qwen_tokenizer): - user_prompt, expect_turn_array, tool_return_array = get_sandbox_fusion_messages() + user_prompt, expect_turn_array, tool_return_array = ( + get_sandbox_fusion_messages() + ) prompts = [[message] for message in user_prompt] preencode_turn_array = [ - qwen_tokenizer.apply_chat_template([turn], tokenize=False, add_generation_prompt=False) + qwen_tokenizer.apply_chat_template( + [turn], tokenize=False, add_generation_prompt=False + ) for turn in expect_turn_array ] preencode_tool_return_array = [ - qwen_tokenizer.apply_chat_template([turn], tokenize=False, add_generation_prompt=True) + qwen_tokenizer.apply_chat_template( + [turn], tokenize=False, add_generation_prompt=True + ) for turn in tool_return_array ] return prompts, preencode_turn_array, preencode_tool_return_array @@ -183,7 +193,11 @@ def sandbox_fusion_rollout_config(self): tensor_parallel_size = 1 tool_path = "./resource/tool_configs/sandbox_fusion_tool_config" rollout_config = get_rollout_config( - max_response_length, max_prompt_length, dtype, tensor_parallel_size, tool_path + max_response_length, + max_prompt_length, + dtype, + 
tensor_parallel_size, + tool_path, ) return rollout_config @@ -191,10 +205,14 @@ def sandbox_fusion_rollout_config(self): def sandbox_data_proto(self, sandbox_fusion_data, qwen_tokenizer): preencode_prompts, _, _ = sandbox_fusion_data prompts = [ - qwen_tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True) + qwen_tokenizer.apply_chat_template( + message, tokenize=False, add_generation_prompt=True + ) for message in preencode_prompts ] - input_ids, attention_mask, position_ids = prepare_inputs(qwen_tokenizer, prompts, 1000) + input_ids, attention_mask, position_ids = prepare_inputs( + qwen_tokenizer, prompts, 1000 + ) prompt_dict = TensorDict( { "input_ids": input_ids, @@ -216,16 +234,27 @@ def sandbox_data_proto(self, sandbox_fusion_data, qwen_tokenizer): ) index = np.array([0], dtype=object) prompts = DataProto( - batch=prompt_dict, non_tensor_batch={"raw_prompt": messages, "tools_kwargs": tools_kwargs, "index": index} + batch=prompt_dict, + non_tensor_batch={ + "raw_prompt": messages, + "tools_kwargs": tools_kwargs, + "index": index, + }, ) return prompts @pytest.fixture - def mock_rollout(self, sandbox_fusion_rollout_config, qwen_tokenizer, qwen_model_config): + def mock_rollout( + self, sandbox_fusion_rollout_config, qwen_tokenizer, qwen_model_config + ): """Mock the rollout instance""" - with patch.object(SGLangRollout, "_init_distributed_env", return_value=None), patch.object( + with patch.object( + SGLangRollout, "_init_distributed_env", return_value=None + ), patch.object( SGLangRollout, "_init_inference_engine", return_value=None - ), patch.object(SGLangRollout, "_init_sampling_params", return_value=None): + ), patch.object( + SGLangRollout, "_init_sampling_params", return_value=None + ): rollout = SGLangRollout( actor_module="", config=sandbox_fusion_rollout_config, @@ -253,7 +282,9 @@ def test_tools_registration(self, mock_rollout): def test_rollout_req_creation(self, mock_rollout, sandbox_data_proto): """Test request 
creation functionality""" - req_list = mock_rollout._preprocess_prompt_to_async_rollout_requests(sandbox_data_proto, n=1) + req_list = mock_rollout._preprocess_prompt_to_async_rollout_requests( + sandbox_data_proto, n=1 + ) assert len(req_list) == 1 assert req_list[0].state == AsyncRolloutRequestStateEnum.PENDING assert len(req_list[0].tool_schemas) == 1 @@ -278,10 +309,14 @@ def test_rollout_req_creation(self, mock_rollout, sandbox_data_proto): ), ) - def test_over_size_case(self, mock_rollout, sandbox_data_proto, sandbox_fusion_data): + def test_over_size_case( + self, mock_rollout, sandbox_data_proto, sandbox_fusion_data + ): """Test over-size response truncation case""" mock_rollout.config.multi_turn.max_assistant_turns = 1 - req = mock_rollout._preprocess_prompt_to_async_rollout_requests(sandbox_data_proto, n=1)[0] + req = mock_rollout._preprocess_prompt_to_async_rollout_requests( + sandbox_data_proto, n=1 + )[0] req = MagicMock(wraps=req, spec=AsyncRolloutRequest) req.finalize = MagicMock() req_list = [req] @@ -308,7 +343,10 @@ def test_over_size_case(self, mock_rollout, sandbox_data_proto, sandbox_fusion_d loop = asyncio.get_event_loop() output_req_list = loop.run_until_complete( asyncio.gather( - *[mock_rollout._async_rollout_a_request(req, True, False) for req in req_list], + *[ + mock_rollout._async_rollout_a_request(req, True, False) + for req in req_list + ], ) ) assert len(output_req_list) == 1 @@ -324,11 +362,15 @@ def test_over_size_case(self, mock_rollout, sandbox_data_proto, sandbox_fusion_d ) @skip_if_valid_sandbox(sandbox_url) - def test_tool_call_basic_case(self, mock_rollout, sandbox_data_proto, sandbox_fusion_data): + def test_tool_call_basic_case( + self, mock_rollout, sandbox_data_proto, sandbox_fusion_data + ): """Test basic tool call case""" mock_rollout.config.multi_turn.max_assistant_turns = 10 mock_rollout._tool_map["code_interpreter"].sandbox_fusion_url = sandbox_url - req = 
mock_rollout._preprocess_prompt_to_async_rollout_requests(sandbox_data_proto, n=1)[0] + req = mock_rollout._preprocess_prompt_to_async_rollout_requests( + sandbox_data_proto, n=1 + )[0] req = MagicMock(wraps=req, spec=AsyncRolloutRequest) req.finalize = MagicMock() req_list = [req] @@ -342,7 +384,13 @@ def test_tool_call_basic_case(self, mock_rollout, sandbox_data_proto, sandbox_fu "text": turn, "meta_info": { "id": "d1188d81cba840359df5b352b344bc8e", - "finish_reason": {"type": "tool_calls" if idx < len(expect_turn_array) - 1 else "stop"}, + "finish_reason": { + "type": ( + "tool_calls" + if idx < len(expect_turn_array) - 1 + else "stop" + ) + }, "prompt_tokens": len(turn), "completion_tokens": 100, "cached_tokens": 0, @@ -359,7 +407,10 @@ def test_tool_call_basic_case(self, mock_rollout, sandbox_data_proto, sandbox_fu loop = asyncio.get_event_loop() output_req_list = loop.run_until_complete( asyncio.gather( - *[mock_rollout._async_rollout_a_request(req, True, False) for req in req_list], + *[ + mock_rollout._async_rollout_a_request(req, True, False) + for req in req_list + ], ) ) assert len(output_req_list) == 1 @@ -377,11 +428,15 @@ def test_tool_call_basic_case(self, mock_rollout, sandbox_data_proto, sandbox_fu assert code_counter == 2 @skip_if_valid_sandbox(sandbox_url) - def test_tool_call_batch_case(self, mock_rollout, sandbox_data_proto, sandbox_fusion_data): + def test_tool_call_batch_case( + self, mock_rollout, sandbox_data_proto, sandbox_fusion_data + ): """Test batch tool call case""" mock_rollout.config.multi_turn.max_assistant_turns = 10 mock_rollout._tool_map["code_interpreter"].sandbox_fusion_url = sandbox_url - req = mock_rollout._preprocess_prompt_to_async_rollout_requests(sandbox_data_proto, n=1)[0] + req = mock_rollout._preprocess_prompt_to_async_rollout_requests( + sandbox_data_proto, n=1 + )[0] req_nums = 100 req_list = [] req_turns_counter = {} @@ -400,7 +455,13 @@ def test_tool_call_batch_case(self, mock_rollout, sandbox_data_proto, 
sandbox_fu "text": turn, "meta_info": { "id": "d1188d81cba840359df5b352b344bc8e", - "finish_reason": {"type": "tool_calls" if idx < len(expect_turn_array) - 1 else "stop"}, + "finish_reason": { + "type": ( + "tool_calls" + if idx < len(expect_turn_array) - 1 + else "stop" + ) + }, "prompt_tokens": len(turn), "completion_tokens": 100, "cached_tokens": 0, @@ -415,19 +476,30 @@ def test_tool_call_batch_case(self, mock_rollout, sandbox_data_proto, sandbox_fu req_turns_counter[_temp_req.batch_data_id] = 0 async def hacked_handle_engine_call( - self, _req: AsyncRolloutRequest, do_sample: bool, is_validate: bool, **kwargs + self, + _req: AsyncRolloutRequest, + do_sample: bool, + is_validate: bool, + **kwargs, ): - result = req_turns_map[_req.batch_data_id][req_turns_counter[_req.batch_data_id]] + result = req_turns_map[_req.batch_data_id][ + req_turns_counter[_req.batch_data_id] + ] req_turns_counter[_req.batch_data_id] += 1 re = await result return re - with patch.object(SGLangRollout, "_handle_engine_call", new=hacked_handle_engine_call): + with patch.object( + SGLangRollout, "_handle_engine_call", new=hacked_handle_engine_call + ): mock_rollout._tp_rank = 0 loop = asyncio.get_event_loop() output_req_list = loop.run_until_complete( asyncio.gather( - *[mock_rollout._async_rollout_a_request(req, True, False) for req in req_list], + *[ + mock_rollout._async_rollout_a_request(req, True, False) + for req in req_list + ], ) ) assert len(output_req_list) == req_nums @@ -562,9 +634,14 @@ def test_rate_limiter(self): # exec_worker = ExecutionWorker.options(max_concurrency=10).remote(enable_global_rate_limit=True, rate_limit=3) exec_worker = init_execution_pool( - num_workers=10, enable_global_rate_limit=True, rate_limit=3, mode=PoolMode.ThreadMode + num_workers=10, + enable_global_rate_limit=True, + rate_limit=3, + mode=PoolMode.ThreadMode, + ) + center = TestActor.options(get_if_exists=True, name="test-actor").remote( + self.rank, self.world_size ) - center = 
TestActor.options(get_if_exists=True, name="test-actor").remote(self.rank, self.world_size) ray.get(exec_worker.ping.remote()) def fn(i): @@ -594,7 +671,10 @@ def test_rotten_execution(self): # exec_worker = ExecutionWorker.options(max_concurrency=10).remote(enable_global_rate_limit=True, rate_limit=6) exec_worker = init_execution_pool( - num_workers=10, enable_global_rate_limit=True, rate_limit=6, mode=PoolMode.ThreadMode + num_workers=10, + enable_global_rate_limit=True, + rate_limit=6, + mode=PoolMode.ThreadMode, ) ray.get(exec_worker.ping.remote()) @@ -609,8 +689,12 @@ def fn(i): results = loop.run_until_complete(asyncio.gather(*tasks)) expect_result = [None] + list(range(10)) + list(range(11, 20)) sorted_data = sorted(results, key=lambda x: (x is not None, x)) - assert sorted_data == expect_result, f"results: {results}, expect_result: {expect_result}" - rate_limiter = TokenBucketWorker.options(name="rate-limiter", get_if_exists=True).remote() + assert ( + sorted_data == expect_result + ), f"results: {results}, expect_result: {expect_result}" + rate_limiter = TokenBucketWorker.options( + name="rate-limiter", get_if_exists=True + ).remote() rate = ray.get(rate_limiter.get_current_count.remote()) assert rate == 0, f"rate: {rate}" @@ -626,9 +710,14 @@ def test_rate_limiter(self): # exec_worker = ExecutionWorker.options(max_concurrency=10).remote(enable_global_rate_limit=True, rate_limit=6) exec_worker = init_execution_pool( - num_workers=10, enable_global_rate_limit=True, rate_limit=6, mode=PoolMode.ThreadMode + num_workers=10, + enable_global_rate_limit=True, + rate_limit=6, + mode=PoolMode.ThreadMode, + ) + center = TestActor.options(get_if_exists=True, name="test-actor").remote( + self.rank, self.world_size ) - center = TestActor.options(get_if_exists=True, name="test-actor").remote(self.rank, self.world_size) ray.get(exec_worker.ping.remote()) def fn(i): diff --git a/Agent0/executor_train/verl/tests/workers/rollout/test_sglang_async_rollout_w_interaction.py 
b/Agent0/executor_train/verl/tests/workers/rollout/test_sglang_async_rollout_w_interaction.py index 3ccde18..0fe6680 100644 --- a/Agent0/executor_train/verl/tests/workers/rollout/test_sglang_async_rollout_w_interaction.py +++ b/Agent0/executor_train/verl/tests/workers/rollout/test_sglang_async_rollout_w_interaction.py @@ -61,21 +61,43 @@ def test_async_sglang_rollout_w_interaction(): ] ] interaction_kwargs = [ - {"name": "gsm8k", "query": "Who won the Champions League in 2019?", "ground_truth": "Real Madrid"}, - {"name": "gsm8k", "query": "The founder of Apple is", "ground_truth": "Steve Jobs"}, - {"name": "gsm8k", "query": "What's the best way to learn python?", "ground_truth": "Learn python from scratch"}, + { + "name": "gsm8k", + "query": "Who won the Champions League in 2019?", + "ground_truth": "Real Madrid", + }, + { + "name": "gsm8k", + "query": "The founder of Apple is", + "ground_truth": "Steve Jobs", + }, + { + "name": "gsm8k", + "query": "What's the best way to learn python?", + "ground_truth": "Learn python from scratch", + }, ] prompts = [ - tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True) + tokenizer.apply_chat_template( + message, tokenize=False, add_generation_prompt=True + ) for message in preencode_prompts ] - input_ids, attention_mask, position_ids = prepare_inputs(tokenizer, prompts, max_prompt_length) + input_ids, attention_mask, position_ids = prepare_inputs( + tokenizer, prompts, max_prompt_length + ) - hf_response_tokens = generate_hf_output(actor_model, input_ids, attention_mask, tokenizer, max_response_length) + hf_response_tokens = generate_hf_output( + actor_model, input_ids, attention_mask, tokenizer, max_response_length + ) - fsdp_device_mesh = init_device_mesh("cuda", mesh_shape=(tensor_parallel_size,), mesh_dim_names=("fsdp",)) + fsdp_device_mesh = init_device_mesh( + "cuda", mesh_shape=(tensor_parallel_size,), mesh_dim_names=("fsdp",) + ) inference_device_mesh_cpu = init_device_mesh( - "cpu", 
mesh_shape=(1, tensor_parallel_size, 1), mesh_dim_names=("dp", "infer_tp", "pp") + "cpu", + mesh_shape=(1, tensor_parallel_size, 1), + mesh_dim_names=("dp", "infer_tp", "pp"), ) fsdp_model = FSDP( @@ -94,7 +116,11 @@ def test_async_sglang_rollout_w_interaction(): interaction_config = { "interaction": [ - {"name": "gsm8k", "class_name": "verl.interactions.gsm8k_interaction.Gsm8kInteraction", "config": {}} + { + "name": "gsm8k", + "class_name": "verl.interactions.gsm8k_interaction.Gsm8kInteraction", + "config": {}, + } ] } @@ -103,7 +129,12 @@ def test_async_sglang_rollout_w_interaction(): interaction_config_path = f.name rollout_config = get_rollout_config( - max_response_length, max_prompt_length, dtype, tensor_parallel_size, None, interaction_config_path + max_response_length, + max_prompt_length, + dtype, + tensor_parallel_size, + None, + interaction_config_path, ) rollout = SGLangRollout( actor_module=local_model_path, @@ -135,7 +166,10 @@ def test_async_sglang_rollout_w_interaction(): messages = np.asarray(preencode_prompts) prompts = DataProto( batch=prompt_dict, - non_tensor_batch={"raw_prompt": messages, "interaction_kwargs": np.asarray(interaction_kwargs)}, + non_tensor_batch={ + "raw_prompt": messages, + "interaction_kwargs": np.asarray(interaction_kwargs), + }, ) prompts.meta_info.update( diff --git a/Agent0/executor_train/verl/tests/workers/rollout/test_sglang_async_rollout_w_tools.py b/Agent0/executor_train/verl/tests/workers/rollout/test_sglang_async_rollout_w_tools.py index 20faab8..753e046 100644 --- a/Agent0/executor_train/verl/tests/workers/rollout/test_sglang_async_rollout_w_tools.py +++ b/Agent0/executor_train/verl/tests/workers/rollout/test_sglang_async_rollout_w_tools.py @@ -61,16 +61,26 @@ def test_async_sglang_rollout_w_tool(): ] ] prompts = [ - tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True) + tokenizer.apply_chat_template( + message, tokenize=False, add_generation_prompt=True + ) for message in 
preencode_prompts ] - input_ids, attention_mask, position_ids = prepare_inputs(tokenizer, prompts, max_prompt_length) + input_ids, attention_mask, position_ids = prepare_inputs( + tokenizer, prompts, max_prompt_length + ) - hf_response_tokens = generate_hf_output(actor_model, input_ids, attention_mask, tokenizer, max_response_length) + hf_response_tokens = generate_hf_output( + actor_model, input_ids, attention_mask, tokenizer, max_response_length + ) - fsdp_device_mesh = init_device_mesh("cuda", mesh_shape=(tensor_parallel_size,), mesh_dim_names=("fsdp",)) + fsdp_device_mesh = init_device_mesh( + "cuda", mesh_shape=(tensor_parallel_size,), mesh_dim_names=("fsdp",) + ) inference_device_mesh_cpu = init_device_mesh( - "cpu", mesh_shape=(1, tensor_parallel_size, 1), mesh_dim_names=("dp", "infer_tp", "pp") + "cpu", + mesh_shape=(1, tensor_parallel_size, 1), + mesh_dim_names=("dp", "infer_tp", "pp"), ) fsdp_model = FSDP( diff --git a/Agent0/executor_train/verl/tests/workers/rollout/test_sglang_multi_interaction.py b/Agent0/executor_train/verl/tests/workers/rollout/test_sglang_multi_interaction.py index 465470f..4cb5b05 100644 --- a/Agent0/executor_train/verl/tests/workers/rollout/test_sglang_multi_interaction.py +++ b/Agent0/executor_train/verl/tests/workers/rollout/test_sglang_multi_interaction.py @@ -119,11 +119,15 @@ def test_initialize_multiple_interactions(self): # Mock SGLang engine and initialization methods like the reference test with ( patch.object(SGLangRollout, "_init_distributed_env", return_value=None), - patch.object(SGLangRollout, "_init_inference_engine", return_value=None), + patch.object( + SGLangRollout, "_init_inference_engine", return_value=None + ), patch.object(SGLangRollout, "_init_sampling_params", return_value=None), ): # Create a real tokenizer like the reference test - tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B", padding_side="left") + tokenizer = AutoTokenizer.from_pretrained( + "Qwen/Qwen2.5-0.5B", padding_side="left" + ) 
tokenizer.pad_token = tokenizer.eos_token # Mock model config @@ -154,12 +158,22 @@ def test_initialize_multiple_interactions(self): assert "mock_agent2" in rollout.interaction_map # Use class name comparison instead of isinstance for multi-process compatibility - assert rollout.interaction_map["mock_agent1"].__class__.__name__ == "MockInteraction" - assert rollout.interaction_map["mock_agent2"].__class__.__name__ == "MockInteraction" + assert ( + rollout.interaction_map["mock_agent1"].__class__.__name__ + == "MockInteraction" + ) + assert ( + rollout.interaction_map["mock_agent2"].__class__.__name__ + == "MockInteraction" + ) # Also check that they are instances of BaseInteraction (which should work across processes) - assert isinstance(rollout.interaction_map["mock_agent1"], BaseInteraction) - assert isinstance(rollout.interaction_map["mock_agent2"], BaseInteraction) + assert isinstance( + rollout.interaction_map["mock_agent1"], BaseInteraction + ) + assert isinstance( + rollout.interaction_map["mock_agent2"], BaseInteraction + ) # Check that names were set correctly assert rollout.interaction_map["mock_agent1"].name == "mock_agent1" @@ -176,10 +190,14 @@ def test_interaction_selection_by_name(self): try: with ( patch.object(SGLangRollout, "_init_distributed_env", return_value=None), - patch.object(SGLangRollout, "_init_inference_engine", return_value=None), + patch.object( + SGLangRollout, "_init_inference_engine", return_value=None + ), patch.object(SGLangRollout, "_init_sampling_params", return_value=None), ): - tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B", padding_side="left") + tokenizer = AutoTokenizer.from_pretrained( + "Qwen/Qwen2.5-0.5B", padding_side="left" + ) tokenizer.pad_token = tokenizer.eos_token mock_model_config = MagicMock() @@ -201,7 +219,11 @@ def test_interaction_selection_by_name(self): ) # Test interaction selection logic - from verl.workers.rollout.schemas import AsyncRolloutRequest, AsyncRolloutRequestStateEnum, Message 
+ from verl.workers.rollout.schemas import ( + AsyncRolloutRequest, + AsyncRolloutRequestStateEnum, + Message, + ) # Create a mock request with specific interaction name req = AsyncRolloutRequest( @@ -288,10 +310,14 @@ def test_fallback_to_default_interaction(self): try: with ( patch.object(SGLangRollout, "_init_distributed_env", return_value=None), - patch.object(SGLangRollout, "_init_inference_engine", return_value=None), + patch.object( + SGLangRollout, "_init_inference_engine", return_value=None + ), patch.object(SGLangRollout, "_init_sampling_params", return_value=None), ): - tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B", padding_side="left") + tokenizer = AutoTokenizer.from_pretrained( + "Qwen/Qwen2.5-0.5B", padding_side="left" + ) tokenizer.pad_token = tokenizer.eos_token mock_model_config = MagicMock() @@ -329,10 +355,14 @@ def test_error_on_missing_interaction(self): try: with ( patch.object(SGLangRollout, "_init_distributed_env", return_value=None), - patch.object(SGLangRollout, "_init_inference_engine", return_value=None), + patch.object( + SGLangRollout, "_init_inference_engine", return_value=None + ), patch.object(SGLangRollout, "_init_sampling_params", return_value=None), ): - tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B", padding_side="left") + tokenizer = AutoTokenizer.from_pretrained( + "Qwen/Qwen2.5-0.5B", padding_side="left" + ) tokenizer.pad_token = tokenizer.eos_token mock_model_config = MagicMock() @@ -401,7 +431,9 @@ def test_backward_compatibility_no_interaction_config(self): patch.object(SGLangRollout, "_init_inference_engine", return_value=None), patch.object(SGLangRollout, "_init_sampling_params", return_value=None), ): - tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B", padding_side="left") + tokenizer = AutoTokenizer.from_pretrained( + "Qwen/Qwen2.5-0.5B", padding_side="left" + ) tokenizer.pad_token = tokenizer.eos_token mock_model_config = MagicMock() diff --git 
a/Agent0/executor_train/verl/tests/workers/rollout/test_sglang_spmd.py b/Agent0/executor_train/verl/tests/workers/rollout/test_sglang_spmd.py index e6b7256..35034ab 100644 --- a/Agent0/executor_train/verl/tests/workers/rollout/test_sglang_spmd.py +++ b/Agent0/executor_train/verl/tests/workers/rollout/test_sglang_spmd.py @@ -35,7 +35,9 @@ def _pre_process_inputs(pad_token_id, prompt_token_ids: torch.Tensor): - non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0] + non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][ + 0 + ] token_ids = prompt_token_ids[non_pad_index:].tolist() return token_ids @@ -51,14 +53,24 @@ def test_sglang_spmd(): local_model_path = "Qwen/Qwen2.5-0.5B" tokenizer, actor_model = load_tokenizer_and_model(local_model_path) - preencode_prompts = ["Who won the Champions League in 2019?", "The founder of Apple is", "What's your name?"] - input_ids, attention_mask, _ = prepare_inputs(tokenizer, preencode_prompts, max_prompt_length) + preencode_prompts = [ + "Who won the Champions League in 2019?", + "The founder of Apple is", + "What's your name?", + ] + input_ids, attention_mask, _ = prepare_inputs( + tokenizer, preencode_prompts, max_prompt_length + ) - hf_response_tokens = generate_hf_output(actor_model, input_ids, attention_mask, tokenizer, max_response_length) + hf_response_tokens = generate_hf_output( + actor_model, input_ids, attention_mask, tokenizer, max_response_length + ) tensor_parallel_size = 2 inference_device_mesh_cpu = init_device_mesh( - "cpu", mesh_shape=(1, tensor_parallel_size, 1), mesh_dim_names=["dp", "tp", "pp"] + "cpu", + mesh_shape=(1, tensor_parallel_size, 1), + mesh_dim_names=["dp", "tp", "pp"], ) tp_rank = inference_device_mesh_cpu["tp"].get_local_rank() @@ -74,7 +86,11 @@ def test_sglang_spmd(): input_ids = input_ids.cuda() idx_list = [] - pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id + pad_token_id = ( 
+ tokenizer.pad_token_id + if tokenizer.pad_token_id is not None + else tokenizer.eos_token_id + ) for i in range(input_ids.shape[0]): idx_list.append(_pre_process_inputs(pad_token_id, input_ids[i])) @@ -93,7 +109,9 @@ def test_sglang_spmd(): ) loop = asyncio.get_event_loop() - outputs = loop.run_until_complete(llm.async_generate(input_ids=idx_list, sampling_params=sampling_params)) + outputs = loop.run_until_complete( + llm.async_generate(input_ids=idx_list, sampling_params=sampling_params) + ) else: outputs = None @@ -108,7 +126,9 @@ def test_sglang_spmd(): sglang_response_tokens = [output["text"] for output in outputs] print(f"sglang response: {sglang_response_tokens}") - assert are_lists_similar(hf_response_tokens, sglang_response_tokens), "Strings differ more than 10%:\n" + assert are_lists_similar( + hf_response_tokens, sglang_response_tokens + ), "Strings differ more than 10%:\n" print("SPMD Test Passed!") torch.distributed.barrier() diff --git a/Agent0/executor_train/verl/tests/workers/rollout/utils_sglang.py b/Agent0/executor_train/verl/tests/workers/rollout/utils_sglang.py index 2e22e47..eb204a2 100644 --- a/Agent0/executor_train/verl/tests/workers/rollout/utils_sglang.py +++ b/Agent0/executor_train/verl/tests/workers/rollout/utils_sglang.py @@ -88,23 +88,35 @@ def clean_torchelastic_env(): def load_tokenizer_and_model(local_model_path, dtype="bfloat16"): tokenizer = AutoTokenizer.from_pretrained(local_model_path, padding_side="left") tokenizer.pad_token = tokenizer.eos_token - model = AutoModelForCausalLM.from_pretrained(local_model_path, torch_dtype=getattr(torch, dtype), device_map="cuda") + model = AutoModelForCausalLM.from_pretrained( + local_model_path, torch_dtype=getattr(torch, dtype), device_map="cuda" + ) return tokenizer, model def prepare_inputs(tokenizer, prompts, max_prompt_length): - pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id + pad_token_id = ( + tokenizer.pad_token_id + if 
tokenizer.pad_token_id is not None + else tokenizer.eos_token_id + ) tokenized = tokenizer(prompts, return_tensors="pt", padding=True) - input_ids = pad_sequence_to_length(tokenized["input_ids"], max_prompt_length, pad_token_id, left_pad=True) + input_ids = pad_sequence_to_length( + tokenized["input_ids"], max_prompt_length, pad_token_id, left_pad=True + ) attention_mask = pad_sequence_to_length( tokenized["attention_mask"], max_prompt_length, pad_token_id=0, left_pad=True ) position_ids = compute_position_id_with_mask(attention_mask) - position_ids = pad_sequence_to_length(position_ids, max_prompt_length, pad_token_id=0, left_pad=True) + position_ids = pad_sequence_to_length( + position_ids, max_prompt_length, pad_token_id=0, left_pad=True + ) return input_ids, attention_mask, position_ids -def generate_hf_output(model, input_ids, attention_mask, tokenizer, max_response_length): +def generate_hf_output( + model, input_ids, attention_mask, tokenizer, max_response_length +): generation_config = GenerationConfig(do_sample=False) output = model.generate( input_ids=input_ids.cuda(), diff --git a/Agent0/executor_train/verl/verl/__init__.py b/Agent0/executor_train/verl/verl/__init__.py index 593f3dc..65788c3 100644 --- a/Agent0/executor_train/verl/verl/__init__.py +++ b/Agent0/executor_train/verl/verl/__init__.py @@ -37,7 +37,9 @@ if os.getenv("VERL_USE_MODELSCOPE", "False").lower() == "true": if importlib.util.find_spec("modelscope") is None: - raise ImportError("You are using the modelscope hub, please install modelscope by `pip install modelscope -U`") + raise ImportError( + "You are using the modelscope hub, please install modelscope by `pip install modelscope -U`" + ) # Patch hub to download models from modelscope to speed up. 
from modelscope.utils.hf_util import patch_hub diff --git a/Agent0/executor_train/verl/verl/experimental/agent_loop/agent_loop.py b/Agent0/executor_train/verl/verl/experimental/agent_loop/agent_loop.py index e16f1a8..f0ad869 100644 --- a/Agent0/executor_train/verl/verl/experimental/agent_loop/agent_loop.py +++ b/Agent0/executor_train/verl/verl/experimental/agent_loop/agent_loop.py @@ -32,7 +32,11 @@ from verl.single_controller.ray.base import RayWorkerGroup from verl.utils import hf_tokenizer from verl.utils.fs import copy_to_local -from verl.utils.rollout_trace import RolloutTraceConfig, rollout_trace_attr, rollout_trace_op +from verl.utils.rollout_trace import ( + RolloutTraceConfig, + rollout_trace_attr, + rollout_trace_op, +) from verl.workers.rollout.async_server import async_server_class logger = logging.getLogger(__file__) @@ -46,7 +50,12 @@ class AsyncLLMServerManager: - Sticky session: send multi-turn chat completions to same server for automatic prefix caching """ - def __init__(self, config: DictConfig, server_handles: list[ray.actor.ActorHandle], max_cache_size: int = 10000): + def __init__( + self, + config: DictConfig, + server_handles: list[ray.actor.ActorHandle], + max_cache_size: int = 10000, + ): """Initialize the AsyncLLMServerManager. 
Args: @@ -59,7 +68,9 @@ def __init__(self, config: DictConfig, server_handles: list[ray.actor.ActorHandl random.shuffle(self.server_handles) # Least requests load balancing - self.weighted_serveres = [[0, (hash(server), server)] for server in server_handles] + self.weighted_serveres = [ + [0, (hash(server), server)] for server in server_handles + ] heapq.heapify(self.weighted_serveres) # LRU cache to map request_id to server @@ -126,7 +137,12 @@ class AgentLoopBase(ABC): _class_initialized = False - def __init__(self, config: DictConfig, server_manager: AsyncLLMServerManager, tokenizer: AutoTokenizer): + def __init__( + self, + config: DictConfig, + server_manager: AsyncLLMServerManager, + tokenizer: AutoTokenizer, + ): """Initialize agent loop. Args: @@ -148,7 +164,9 @@ def init_class(cls, config: DictConfig, tokenizer: AutoTokenizer): cls._class_initialized = True @abstractmethod - async def run(self, messages: list[dict[str, Any]], sampling_params: dict[str, Any]) -> AgentLoopOutput: + async def run( + self, messages: list[dict[str, Any]], sampling_params: dict[str, Any] + ) -> AgentLoopOutput: """Run agent loop to interact with LLM server and environment. 
Args: @@ -224,7 +242,9 @@ async def generate_sequences(self, batch: DataProto) -> DataProto: # by default, we assume it's a single turn agent if "agent_name" not in batch.non_tensor_batch: - batch.non_tensor_batch["agent_name"] = np.array(["single_turn_agent"] * len(batch), dtype=object) + batch.non_tensor_batch["agent_name"] = np.array( + ["single_turn_agent"] * len(batch), dtype=object + ) tasks = [] agent_names = batch.non_tensor_batch["agent_name"] @@ -234,11 +254,19 @@ async def generate_sequences(self, batch: DataProto) -> DataProto: else: index = np.arange(len(raw_prompts)) - trajectory_info = await get_trajectory_info(batch.meta_info.get("global_steps", -1), index) + trajectory_info = await get_trajectory_info( + batch.meta_info.get("global_steps", -1), index + ) - for agent_name, messages, trajectory in zip(agent_names, raw_prompts, trajectory_info, strict=True): + for agent_name, messages, trajectory in zip( + agent_names, raw_prompts, trajectory_info, strict=True + ): tasks.append( - asyncio.create_task(self._run_agent_loop(agent_name, messages.tolist(), sampling_params, trajectory)) + asyncio.create_task( + self._run_agent_loop( + agent_name, messages.tolist(), sampling_params, trajectory + ) + ) ) outputs = await asyncio.gather(*tasks) @@ -253,10 +281,14 @@ async def _run_agent_loop( trajectory: dict[str, Any], ) -> AgentLoopOutput: with rollout_trace_attr( - step=trajectory["step"], sample_index=trajectory["sample_index"], rollout_n=trajectory["rollout_n"] + step=trajectory["step"], + sample_index=trajectory["sample_index"], + rollout_n=trajectory["rollout_n"], ): agent_loop_class = self.get_agent_loop_class(agent_name) - agent_loop = agent_loop_class(self.config, self.server_manager, self.tokenizer) + agent_loop = agent_loop_class( + self.config, self.server_manager, self.tokenizer + ) output = await agent_loop.run(messages, sampling_params) return output @@ -276,7 +308,9 @@ def get_agent_loop_class(self, agent_name: str) -> type[AgentLoopBase]: 
ValueError: If the agent_name is not recognized. """ # TODO: add tool agent registrary - from verl.experimental.agent_loop.single_turn_agent_loop import SingleTurnAgentLoop + from verl.experimental.agent_loop.single_turn_agent_loop import ( + SingleTurnAgentLoop, + ) from verl.experimental.agent_loop.tool_agent_loop import ToolAgentLoop if agent_name == "single_turn_agent": @@ -302,7 +336,10 @@ def _postprocess(self, inputs: list[AgentLoopOutput]) -> DataProto: return_tensors="pt", return_attention_mask=True, ) - prompt_ids, prompt_attention_mask = outputs["input_ids"], outputs["attention_mask"] + prompt_ids, prompt_attention_mask = ( + outputs["input_ids"], + outputs["attention_mask"], + ) # responses self.tokenizer.padding_side = "right" @@ -313,7 +350,10 @@ def _postprocess(self, inputs: list[AgentLoopOutput]) -> DataProto: return_tensors="pt", return_attention_mask=True, ) - response_ids, response_attention_mask = outputs["input_ids"], outputs["attention_mask"] + response_ids, response_attention_mask = ( + outputs["input_ids"], + outputs["attention_mask"], + ) # response_mask outputs = self.tokenizer.pad( @@ -324,13 +364,15 @@ def _postprocess(self, inputs: list[AgentLoopOutput]) -> DataProto: return_attention_mask=False, ) response_mask = outputs["input_ids"] - assert response_ids.shape == response_mask.shape, ( - f"mismatch in response_ids and response_mask shape: {response_ids.shape} vs {response_mask.shape}" - ) + assert ( + response_ids.shape == response_mask.shape + ), f"mismatch in response_ids and response_mask shape: {response_ids.shape} vs {response_mask.shape}" response_mask = response_mask * response_attention_mask input_ids = torch.cat([prompt_ids, response_ids], dim=1) - attention_mask = torch.cat([prompt_attention_mask, response_attention_mask], dim=1) + attention_mask = torch.cat( + [prompt_attention_mask, response_attention_mask], dim=1 + ) position_ids = (attention_mask.cumsum(dim=1) - 1) * attention_mask batch = TensorDict( @@ -347,7 +389,11 
@@ def _postprocess(self, inputs: list[AgentLoopOutput]) -> DataProto: num_turns = np.array([input.num_turns for input in inputs], dtype=np.int32) metrics = [input.metrics.model_dump() for input in inputs] - return DataProto(batch=batch, non_tensor_batch={"__num_turns__": num_turns}, meta_info={"metrics": metrics}) + return DataProto( + batch=batch, + non_tensor_batch={"__num_turns__": num_turns}, + meta_info={"metrics": metrics}, + ) async def get_trajectory_info(step, index): @@ -359,7 +405,9 @@ async def get_trajectory_info(step, index): rollout_n += 1 else: rollout_n = 0 - trajectory_info.append({"step": step, "sample_index": index[i], "rollout_n": rollout_n}) + trajectory_info.append( + {"step": step, "sample_index": index[i], "rollout_n": rollout_n} + ) return trajectory_info @@ -383,10 +431,14 @@ def __init__(self, config: DictConfig, worker_group: RayWorkerGroup): self.sleep() def _initialize_llm_servers(self): - self.rollout_tp_size = self.config.actor_rollout_ref.rollout.tensor_model_parallel_size + self.rollout_tp_size = ( + self.config.actor_rollout_ref.rollout.tensor_model_parallel_size + ) self.rollout_dp_size = self.worker_group.world_size // self.rollout_tp_size - register_center = ray.get_actor(f"{self.worker_group.name_prefix}_register_center") + register_center = ray.get_actor( + f"{self.worker_group.name_prefix}_register_center" + ) workers_info = ray.get(register_center.get_worker_info.remote()) assert len(workers_info) == self.worker_group.world_size @@ -400,7 +452,9 @@ def _initialize_llm_servers(self): rollout_backend_class=self.config.actor_rollout_ref.rollout.agent.custom_async_server.name, ) else: - server_class = async_server_class(rollout_backend=self.config.actor_rollout_ref.rollout.name) + server_class = async_server_class( + rollout_backend=self.config.actor_rollout_ref.rollout.name + ) # Start all server instances, restart if address already in use. 
unready_dp_ranks = set(range(self.rollout_dp_size)) @@ -413,7 +467,12 @@ def _initialize_llm_servers(self): soft=False, ), name=f"async_llm_server_{rollout_dp_rank}", - ).remote(self.config, self.rollout_dp_size, rollout_dp_rank, self.worker_group.name_prefix) + ).remote( + self.config, + self.rollout_dp_size, + rollout_dp_rank, + self.worker_group.name_prefix, + ) for rollout_dp_rank in unready_dp_ranks } @@ -425,7 +484,9 @@ def _initialize_llm_servers(self): unready_dp_ranks.remove(rollout_dp_rank) except Exception: ray.kill(server) - print(f"rollout server {rollout_dp_rank} failed, maybe address already in use, restarting...") + print( + f"rollout server {rollout_dp_rank} failed, maybe address already in use, restarting..." + ) # All server instances are ready, init AsyncLLM engine. ray.get([server.init_engine.remote() for server in self.async_llm_servers]) @@ -462,16 +523,24 @@ def generate_sequences(self, prompts: DataProto) -> DataProto: self.sleep() # calculate performance metrics - metrics = [output.meta_info["metrics"] for output in outputs] # List[List[Dict[str, str]]] + metrics = [ + output.meta_info["metrics"] for output in outputs + ] # List[List[Dict[str, str]]] timing = self._performance_metrics(metrics, output) output.meta_info = {"timing": timing} return output - def _performance_metrics(self, metrics: list[list[dict[str, str]]], output: DataProto) -> dict[str, float]: + def _performance_metrics( + self, metrics: list[list[dict[str, str]]], output: DataProto + ) -> dict[str, float]: timing = {} - t_generate_sequences = np.array([metric["generate_sequences"] for chunk in metrics for metric in chunk]) - t_tool_calls = np.array([metric["tool_calls"] for chunk in metrics for metric in chunk]) + t_generate_sequences = np.array( + [metric["generate_sequences"] for chunk in metrics for metric in chunk] + ) + t_tool_calls = np.array( + [metric["tool_calls"] for chunk in metrics for metric in chunk] + ) timing["agent_loop/generate_sequences/min"] = 
t_generate_sequences.min() timing["agent_loop/generate_sequences/max"] = t_generate_sequences.max() timing["agent_loop/generate_sequences/mean"] = t_generate_sequences.mean() @@ -485,8 +554,12 @@ def _performance_metrics(self, metrics: list[list[dict[str, str]]], output: Data prompt_length = output.batch["prompts"].shape[1] timing["agent_loop/slowest/generate_sequences"] = t_generate_sequences[slowest] timing["agent_loop/slowest/tool_calls"] = t_tool_calls[slowest] - timing["agent_loop/slowest/prompt_length"] = attention_mask[:prompt_length].sum().item() - timing["agent_loop/slowest/response_length"] = attention_mask[prompt_length:].sum().item() + timing["agent_loop/slowest/prompt_length"] = ( + attention_mask[:prompt_length].sum().item() + ) + timing["agent_loop/slowest/response_length"] = ( + attention_mask[prompt_length:].sum().item() + ) return timing diff --git a/Agent0/executor_train/verl/verl/experimental/agent_loop/single_turn_agent_loop.py b/Agent0/executor_train/verl/verl/experimental/agent_loop/single_turn_agent_loop.py index e4021ef..d6a9df8 100644 --- a/Agent0/executor_train/verl/verl/experimental/agent_loop/single_turn_agent_loop.py +++ b/Agent0/executor_train/verl/verl/experimental/agent_loop/single_turn_agent_loop.py @@ -31,16 +31,23 @@ def __init__(self, config, server_manager, tokenizer): self.prompt_length = config.actor_rollout_ref.rollout.prompt_length self.response_length = config.actor_rollout_ref.rollout.response_length - async def run(self, messages: list[dict[str, Any]], sampling_params: dict[str, Any]) -> AgentLoopOutput: + async def run( + self, messages: list[dict[str, Any]], sampling_params: dict[str, Any] + ) -> AgentLoopOutput: metrics = {} request_id = uuid4().hex prompt_ids = await self.loop.run_in_executor( - None, lambda: self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=True) + None, + lambda: self.tokenizer.apply_chat_template( + messages, add_generation_prompt=True, tokenize=True + ), ) with 
simple_timer("generate_sequences", metrics): response_ids = await self.server_manager.generate( - request_id=request_id, prompt_ids=prompt_ids, sampling_params=sampling_params + request_id=request_id, + prompt_ids=prompt_ids, + sampling_params=sampling_params, ) response_mask = [1] * len(response_ids) diff --git a/Agent0/executor_train/verl/verl/experimental/agent_loop/tool_agent_loop.py b/Agent0/executor_train/verl/verl/experimental/agent_loop/tool_agent_loop.py index 2756668..caf00ed 100644 --- a/Agent0/executor_train/verl/verl/experimental/agent_loop/tool_agent_loop.py +++ b/Agent0/executor_train/verl/verl/experimental/agent_loop/tool_agent_loop.py @@ -72,7 +72,10 @@ def __init__(self, tokenizer) -> None: async def extract_tool_calls(self, responses_ids: list[int]) -> list[FunctionCall]: loop = asyncio.get_running_loop() text = await loop.run_in_executor(None, self.tokenizer.decode, responses_ids) - if self.tool_call_start_token not in text or self.tool_call_end_token not in text: + if ( + self.tool_call_start_token not in text + or self.tool_call_end_token not in text + ): return [] matches = self.tool_call_regex.findall(text) @@ -81,7 +84,11 @@ async def extract_tool_calls(self, responses_ids: list[int]) -> list[FunctionCal try: function_call = json.loads(match) name, arguments = function_call["name"], function_call["arguments"] - function_calls.append(FunctionCall(name=name, arguments=json.dumps(arguments, ensure_ascii=False))) + function_calls.append( + FunctionCall( + name=name, arguments=json.dumps(arguments, ensure_ascii=False) + ) + ) except Exception as e: logger.error(f"Failed to decode tool call: {e}") return function_calls @@ -101,29 +108,51 @@ def init_class(cls, config, tokenizer): # Initialize tools from config file cls.tokenizer = tokenizer cls.max_user_turns = config.actor_rollout_ref.rollout.multi_turn.max_user_turns - cls.max_assistant_turns = config.actor_rollout_ref.rollout.multi_turn.max_assistant_turns - cls.max_parallel_calls = 
config.actor_rollout_ref.rollout.multi_turn.max_parallel_calls - cls.max_tool_response_length = config.actor_rollout_ref.rollout.multi_turn.max_tool_response_length - cls.tool_response_truncate_side = config.actor_rollout_ref.rollout.multi_turn.tool_response_truncate_side + cls.max_assistant_turns = ( + config.actor_rollout_ref.rollout.multi_turn.max_assistant_turns + ) + cls.max_parallel_calls = ( + config.actor_rollout_ref.rollout.multi_turn.max_parallel_calls + ) + cls.max_tool_response_length = ( + config.actor_rollout_ref.rollout.multi_turn.max_tool_response_length + ) + cls.tool_response_truncate_side = ( + config.actor_rollout_ref.rollout.multi_turn.tool_response_truncate_side + ) tool_config_path = config.actor_rollout_ref.rollout.multi_turn.tool_config_path - tool_list = initialize_tools_from_config(tool_config_path) if tool_config_path else [] + tool_list = ( + initialize_tools_from_config(tool_config_path) if tool_config_path else [] + ) cls.tools = {tool.name: tool for tool in tool_list} - cls.tool_schemas = [tool.tool_schema.model_dump(exclude_unset=True, exclude_none=True) for tool in tool_list] - cls.tool_parser = cls.get_tool_parser(config.actor_rollout_ref.rollout.multi_turn.format) + cls.tool_schemas = [ + tool.tool_schema.model_dump(exclude_unset=True, exclude_none=True) + for tool in tool_list + ] + cls.tool_parser = cls.get_tool_parser( + config.actor_rollout_ref.rollout.multi_turn.format + ) print(f"Initialized tools: {cls.tools}") cls.prompt_length = config.actor_rollout_ref.rollout.prompt_length cls.response_length = config.actor_rollout_ref.rollout.response_length - cls.system_prompt = tokenizer.apply_chat_template([{}], add_generation_prompt=False, tokenize=True) + cls.system_prompt = tokenizer.apply_chat_template( + [{}], add_generation_prompt=False, tokenize=True + ) @rollout_trace_op - async def run(self, messages: list[dict[str, Any]], sampling_params: dict[str, Any]) -> AgentLoopOutput: + async def run( + self, messages: 
list[dict[str, Any]], sampling_params: dict[str, Any] + ) -> AgentLoopOutput: metrics = {} request_id = uuid4().hex prompt_ids = await self.loop.run_in_executor( None, lambda: self.tokenizer.apply_chat_template( - messages, tools=self.tool_schemas, add_generation_prompt=True, tokenize=True + messages, + tools=self.tool_schemas, + add_generation_prompt=True, + tokenize=True, ), ) response_mask = [] @@ -132,7 +161,9 @@ async def run(self, messages: list[dict[str, Any]], sampling_params: dict[str, A while True: with simple_timer("generate_sequences", metrics): response_ids = await self.server_manager.generate( - request_id=request_id, prompt_ids=prompt_ids, sampling_params=sampling_params + request_id=request_id, + prompt_ids=prompt_ids, + sampling_params=sampling_params, ) prompt_ids += response_ids response_mask += [1] * len(response_ids) @@ -214,12 +245,20 @@ async def _call_tool(self, tool_call: FunctionCall) -> dict[str, str]: if len(tool_response) > self.max_tool_response_length: if self.tool_response_truncate_side == "left": - tool_response = tool_response[: self.max_tool_response_length] + "...(truncated)" + tool_response = ( + tool_response[: self.max_tool_response_length] + "...(truncated)" + ) elif self.tool_response_truncate_side == "right": - tool_response = "(truncated)..." + tool_response[-self.max_tool_response_length :] + tool_response = ( + "(truncated)..." + tool_response[-self.max_tool_response_length :] + ) else: length = self.max_tool_response_length // 2 - tool_response = tool_response[:length] + "...(truncated)..." + tool_response[-length:] + tool_response = ( + tool_response[:length] + + "...(truncated)..." 
+ + tool_response[-length:] + ) return { "role": "tool", diff --git a/Agent0/executor_train/verl/verl/experimental/dynamic_dataset/dynamicgen_dataset.py b/Agent0/executor_train/verl/verl/experimental/dynamic_dataset/dynamicgen_dataset.py index a9532aa..4348a40 100644 --- a/Agent0/executor_train/verl/verl/experimental/dynamic_dataset/dynamicgen_dataset.py +++ b/Agent0/executor_train/verl/verl/experimental/dynamic_dataset/dynamicgen_dataset.py @@ -80,9 +80,9 @@ def __init__( ): super().__init__(data_files, tokenizer, config, processor) self.datagen: AbstractDataGenerator = config.datagen - assert "datagen" in config and config.datagen.get("path", None) is not None, ( - f"datagen path is not set in config: {config}" - ) + assert ( + "datagen" in config and config.datagen.get("path", None) is not None + ), f"datagen path is not set in config: {config}" # Dynamically load the custom datagen class datagen_cls = load_extern_type(config.datagen.path, config.datagen.name) diff --git a/Agent0/executor_train/verl/verl/interactions/base.py b/Agent0/executor_train/verl/verl/interactions/base.py index 7c5d200..99f2d77 100644 --- a/Agent0/executor_train/verl/verl/interactions/base.py +++ b/Agent0/executor_train/verl/verl/interactions/base.py @@ -20,9 +20,13 @@ class BaseInteraction: def __init__(self, config: dict[str, Any]): self.config = config - self.name: str = config.get("name", "interaction_agent") # More general agent default role name + self.name: str = config.get( + "name", "interaction_agent" + ) # More general agent default role name - async def start_interaction(self, instance_id: Optional[str] = None, **kwargs) -> str: + async def start_interaction( + self, instance_id: Optional[str] = None, **kwargs + ) -> str: """Create a tool instance. 
Args: @@ -38,7 +42,9 @@ async def start_interaction(self, instance_id: Optional[str] = None, **kwargs) - async def generate_response( self, instance_id: str, messages: list[dict[str, Any]], **kwargs - ) -> tuple[bool, str, float, dict[str, Any]]: # More clear response generation method + ) -> tuple[ + bool, str, float, dict[str, Any] + ]: # More clear response generation method """ Generates a response for the current turn of interaction. Returns a tuple containing: @@ -51,7 +57,12 @@ async def generate_response( response_content: str = "Your current result seems acceptable." current_turn_score: float = 0.8 additional_data: dict[str, Any] = {} - return should_terminate_sequence, response_content, current_turn_score, additional_data + return ( + should_terminate_sequence, + response_content, + current_turn_score, + additional_data, + ) async def calculate_score(self) -> float: # More clear score calculation method """ @@ -63,7 +74,9 @@ async def calculate_score(self) -> float: # More clear score calculation method score = 0.0 return score - async def finalize_interaction(self) -> None: # More clear interaction end and resource release method + async def finalize_interaction( + self, + ) -> None: # More clear interaction end and resource release method """ Finalizes the interaction session and releases any associated state or resources. 
Simulates: release state diff --git a/Agent0/executor_train/verl/verl/interactions/gsm8k_interaction.py b/Agent0/executor_train/verl/verl/interactions/gsm8k_interaction.py index 365cbb9..92a1cfd 100644 --- a/Agent0/executor_train/verl/verl/interactions/gsm8k_interaction.py +++ b/Agent0/executor_train/verl/verl/interactions/gsm8k_interaction.py @@ -41,7 +41,10 @@ def __init__(self, config: dict): self._instance_dict = {} async def start_interaction( - self, instance_id: Optional[str] = None, ground_truth: Optional[str] = None, **kwargs + self, + instance_id: Optional[str] = None, + ground_truth: Optional[str] = None, + **kwargs ) -> str: if instance_id is None: instance_id = str(uuid4()) diff --git a/Agent0/executor_train/verl/verl/interactions/utils/interaction_registry.py b/Agent0/executor_train/verl/verl/interactions/utils/interaction_registry.py index df747af..ed080a0 100644 --- a/Agent0/executor_train/verl/verl/interactions/utils/interaction_registry.py +++ b/Agent0/executor_train/verl/verl/interactions/utils/interaction_registry.py @@ -65,13 +65,17 @@ def initialize_interactions_from_config(interaction_config_file): class_simple_name = cls_name.split(".")[-1] # Remove "Interaction" suffix if present, otherwise use full class name if class_simple_name.endswith("Interaction"): - name = class_simple_name[:-11].lower() # Remove "Interaction" (11 chars) + name = class_simple_name[ + :-11 + ].lower() # Remove "Interaction" (11 chars) else: name = class_simple_name.lower() # Check for duplicate names if name in interaction_map: - raise ValueError(f"Duplicate interaction name '{name}' found. Each interaction must have a unique name.") + raise ValueError( + f"Duplicate interaction name '{name}' found. Each interaction must have a unique name." 
+ ) # Inject the name into the config config["name"] = name diff --git a/Agent0/executor_train/verl/verl/model_merger/base_model_merger.py b/Agent0/executor_train/verl/verl/model_merger/base_model_merger.py index f13f5fb..81276dd 100644 --- a/Agent0/executor_train/verl/verl/model_merger/base_model_merger.py +++ b/Agent0/executor_train/verl/verl/model_merger/base_model_merger.py @@ -33,13 +33,24 @@ def parse_args(): parser = argparse.ArgumentParser(description="verl model merger") - subparsers = parser.add_subparsers(dest="operation", required=True, help="Specify 'merge' or 'test' operation.") + subparsers = parser.add_subparsers( + dest="operation", required=True, help="Specify 'merge' or 'test' operation." + ) base_op_parser = argparse.ArgumentParser(add_help=False) base_op_parser.add_argument( - "--backend", type=str, required=True, choices=["fsdp", "megatron"], help="The backend of the model" + "--backend", + type=str, + required=True, + choices=["fsdp", "megatron"], + help="The backend of the model", + ) + base_op_parser.add_argument( + "--local_dir", + type=str, + default=None, + help="Path to the saved model checkpoints.", ) - base_op_parser.add_argument("--local_dir", type=str, default=None, help="Path to the saved model checkpoints.") base_op_parser.add_argument( "--tie-word-embedding", action="store_true", @@ -57,22 +68,37 @@ def parse_args(): "fit into GPU memory during initialization.", ) - merge_parser = subparsers.add_parser("merge", parents=[base_op_parser], help="Merge model checkpoints and save.") + merge_parser = subparsers.add_parser( + "merge", parents=[base_op_parser], help="Merge model checkpoints and save." 
+ ) merge_parser.add_argument( - "--target_dir", default="tmp", type=str, help="Directory to save the merged huggingface model" + "--target_dir", + default="tmp", + type=str, + help="Directory to save the merged huggingface model", ) merge_parser.add_argument( - "--hf_upload_path", default=None, type=str, help="Hugging Face repository ID to upload the model" + "--hf_upload_path", + default=None, + type=str, + help="Hugging Face repository ID to upload the model", ) merge_parser.add_argument( - "--private", action="store_true", help="Whether to upload the model to a private Hugging Face repository" + "--private", + action="store_true", + help="Whether to upload the model to a private Hugging Face repository", ) test_parser = subparsers.add_parser( - "test", parents=[base_op_parser], help="Test merged model against a reference Hugging Face model" + "test", + parents=[base_op_parser], + help="Test merged model against a reference Hugging Face model", ) test_parser.add_argument( - "--test_hf_dir", type=str, required=True, help="Path to the reference Hugging Face model directory for testing" + "--test_hf_dir", + type=str, + required=True, + help="Path to the reference Hugging Face model directory for testing", ) args = parser.parse_args() @@ -171,7 +197,9 @@ def get_transformers_auto_model_class(self): elif "ForConditionalGeneration" in self.model_config.architectures[0]: return AutoModelForVision2Seq - raise NotImplementedError(f"Unknown architecture {self.model_config.architectures}") + raise NotImplementedError( + f"Unknown architecture {self.model_config.architectures}" + ) def patch_model_generation_config(self, model): """ @@ -182,7 +210,9 @@ def patch_model_generation_config(self, model): """ if model.can_generate(): try: - model.generation_config = GenerationConfig.from_pretrained(self.hf_model_config_path) + model.generation_config = GenerationConfig.from_pretrained( + self.hf_model_config_path + ) except OSError: print( f"Warning: Generation config file not 
found in {self.hf_model_config_path}, using a " @@ -227,13 +257,19 @@ def save_lora_adapter(self, state_dict: dict[str, torch.Tensor]): "target_modules": list(target_modules), } peft_config = peft.LoraConfig(**peft_dict).to_dict() - peft_config["task_type"] = peft_config["task_type"].value if peft_config["task_type"] else None - peft_config["peft_type"] = peft_config["peft_type"].value if peft_config["peft_type"] else None + peft_config["task_type"] = ( + peft_config["task_type"].value if peft_config["task_type"] else None + ) + peft_config["peft_type"] = ( + peft_config["peft_type"].value if peft_config["peft_type"] else None + ) peft_config["target_modules"] = list(peft_config["target_modules"]) lora_path = os.path.join(self.config.target_dir, "lora_adapter") os.makedirs(lora_path, exist_ok=True) - with open(os.path.join(lora_path, "adapter_config.json"), "w", encoding="utf-8") as f: + with open( + os.path.join(lora_path, "adapter_config.json"), "w", encoding="utf-8" + ) as f: json.dump(peft_config, f, ensure_ascii=False, indent=4) save_file(lora_params, os.path.join(lora_path, "adapter_model.safetensors")) @@ -250,7 +286,9 @@ def save_lora_adapter(self, state_dict: dict[str, torch.Tensor]): def save_hf_model_and_tokenizer(self, state_dict: dict[str, torch.Tensor]): auto_model_class = self.get_transformers_auto_model_class() with init_empty_weights(): - model = auto_model_class.from_config(self.model_config, torch_dtype=torch.bfloat16) + model = auto_model_class.from_config( + self.model_config, torch_dtype=torch.bfloat16 + ) model.to_empty(device="cpu") model = self.patch_model_generation_config(model) @@ -280,7 +318,11 @@ def upload_to_huggingface(self): api = HfApi() try: # Attempt to create repository - api.create_repo(repo_id=self.config.hf_upload_path, private=self.config.private, exist_ok=True) + api.create_repo( + repo_id=self.config.hf_upload_path, + private=self.config.private, + exist_ok=True, + ) except HfHubHTTPError as e: # Handle authentication/API 
errors if e.response.status_code == 401: @@ -288,24 +330,42 @@ def upload_to_huggingface(self): "Hugging Face authentication failed. Verify your token is valid and has write permissions." ) from e elif e.response.status_code == 404: - raise RepositoryNotFoundError(f"Repository path not found: {self.config.hf_upload_path}") from e + raise RepositoryNotFoundError( + f"Repository path not found: {self.config.hf_upload_path}" + ) from e else: - raise ConnectionError(f"Failed to create repository ({e.response.status_code}): {e}") from e + raise ConnectionError( + f"Failed to create repository ({e.response.status_code}): {e}" + ) from e except requests.exceptions.ConnectionError as e: - raise ConnectionError("Network connection failed. Check your internet connection.") from e + raise ConnectionError( + "Network connection failed. Check your internet connection." + ) from e try: # Attempt folder upload - api.upload_folder(folder_path=self.config.target_dir, repo_id=self.config.hf_upload_path, repo_type="model") + api.upload_folder( + folder_path=self.config.target_dir, + repo_id=self.config.hf_upload_path, + repo_type="model", + ) except HfHubHTTPError as e: if e.response.status_code == 401: - raise PermissionError("Authentication failed during upload. Token may have expired.") from e + raise PermissionError( + "Authentication failed during upload. Token may have expired." + ) from e else: - raise RuntimeError(f"Upload failed ({e.response.status_code}): {e}") from e + raise RuntimeError( + f"Upload failed ({e.response.status_code}): {e}" + ) from e except requests.exceptions.ConnectionError as e: - raise ConnectionError("Network interruption during upload. Try again with stable connection.") from e + raise ConnectionError( + "Network interruption during upload. Try again with stable connection." 
+ ) from e except OSError as e: - raise FileNotFoundError(f"Local folder error: {self.config.target_dir} - {str(e)}") from e + raise FileNotFoundError( + f"Local folder error: {self.config.target_dir} - {str(e)}" + ) from e except Exception as e: raise RuntimeError(f"Unexpected error during upload: {str(e)}") from e @@ -315,4 +375,6 @@ def merge_and_save(self): @abstractmethod def cleanup(self): - raise NotImplementedError("Subclasses should implement this method to clean up resources if needed") + raise NotImplementedError( + "Subclasses should implement this method to clean up resources if needed" + ) diff --git a/Agent0/executor_train/verl/verl/model_merger/fsdp_model_merger.py b/Agent0/executor_train/verl/verl/model_merger/fsdp_model_merger.py index 7853b2b..1d1df99 100644 --- a/Agent0/executor_train/verl/verl/model_merger/fsdp_model_merger.py +++ b/Agent0/executor_train/verl/verl/model_merger/fsdp_model_merger.py @@ -93,7 +93,9 @@ def _load_rank_zero_state_dict(self, world_size: int) -> dict: weights_only=False, ) - def _extract_device_mesh_info(self, state_dict: dict, world_size: int) -> tuple[np.ndarray, tuple[str, ...]]: + def _extract_device_mesh_info( + self, state_dict: dict, world_size: int + ) -> tuple[np.ndarray, tuple[str, ...]]: """ Retrieves sharding information (device_mesh, mesh_dim_names) from a DTensor in the state_dict. If no DTensor is found, infers a simple FSDP mesh based on world_size. @@ -117,7 +119,10 @@ def _calculate_shard_configuration( self, mesh: np.ndarray, mesh_dim_names: tuple[str, ...] 
) -> tuple[int, tuple[int, ...]]: """Calculates the total number of shards and the shape of the device mesh.""" - assert mesh_dim_names in (("fsdp",), ("ddp", "fsdp")), f"Unsupported mesh_dim_names {mesh_dim_names}" + assert mesh_dim_names in ( + ("fsdp",), + ("ddp", "fsdp"), + ), f"Unsupported mesh_dim_names {mesh_dim_names}" if "tp" in mesh_dim_names: # TODO: "tp" is not supported yet due to the above assert @@ -129,7 +134,9 @@ def _calculate_shard_configuration( return total_shards, mesh_shape - def _merge_by_placement(self, tensors: list[torch.Tensor], placement: Placement) -> torch.Tensor: + def _merge_by_placement( + self, tensors: list[torch.Tensor], placement: Placement + ) -> torch.Tensor: """Merges a list of tensors based on their DTensor placement""" if placement.is_replicate(): return tensors[0] @@ -141,19 +148,31 @@ def _merge_by_placement(self, tensors: list[torch.Tensor], placement: Placement) raise NotImplementedError(f"Unsupported placement: {placement}") def _load_and_merge_state_dicts( - self, world_size: int, total_shards: int, mesh_shape: tuple[int, ...], mesh_dim_names: tuple[str, ...] 
+ self, + world_size: int, + total_shards: int, + mesh_shape: tuple[int, ...], + mesh_dim_names: tuple[str, ...], ) -> dict[str, torch.Tensor]: model_state_dict_lst = [None] * total_shards def process_one_shard(rank: int, model_state_dict_lst: list): - model_path = Path(self.config.local_dir) / f"model_world_size_{world_size}_rank_{rank}.pt" + model_path = ( + Path(self.config.local_dir) + / f"model_world_size_{world_size}_rank_{rank}.pt" + ) state_dict = torch.load(model_path, map_location="cpu", weights_only=False) model_state_dict_lst[rank] = state_dict return state_dict with ThreadPoolExecutor(max_workers=min(32, os.cpu_count())) as executor: - futures = [executor.submit(process_one_shard, rank, model_state_dict_lst) for rank in range(total_shards)] - for future in tqdm(futures, desc=f"Loading {total_shards} FSDP shards", total=total_shards): + futures = [ + executor.submit(process_one_shard, rank, model_state_dict_lst) + for rank in range(total_shards) + ] + for future in tqdm( + futures, desc=f"Loading {total_shards} FSDP shards", total=total_shards + ): future.result() # Merge state dicts from all shards @@ -207,13 +226,19 @@ def merge_and_save(self): world_size = self._get_world_size() rank_zero_state_dict = self._load_rank_zero_state_dict(world_size) - mesh, mesh_dim_names = self._extract_device_mesh_info(rank_zero_state_dict, world_size) + mesh, mesh_dim_names = self._extract_device_mesh_info( + rank_zero_state_dict, world_size + ) print(f"Got device mesh {mesh}, mesh_dim_names {mesh_dim_names}") - total_shards, mesh_shape = self._calculate_shard_configuration(mesh, mesh_dim_names) + total_shards, mesh_shape = self._calculate_shard_configuration( + mesh, mesh_dim_names + ) print(f"Processing model shards with {total_shards} {mesh_shape} in total") - merged_state_dict = self._load_and_merge_state_dicts(world_size, total_shards, mesh_shape, mesh_dim_names) + merged_state_dict = self._load_and_merge_state_dicts( + world_size, total_shards, mesh_shape, 
mesh_dim_names + ) if self.config.operation == "test": if not self.config.test_hf_dir: @@ -229,7 +254,9 @@ def merge_and_save(self): def _validate_state_dict(self, state_dict: dict[str, torch.Tensor]): auto_model_class = self.get_transformers_auto_model_class() - hf_model = auto_model_class.from_pretrained(self.config.test_hf_dir, torch_dtype=torch.bfloat16) + hf_model = auto_model_class.from_pretrained( + self.config.test_hf_dir, torch_dtype=torch.bfloat16 + ) hf_state_dict = hf_model.state_dict() del hf_model @@ -237,27 +264,35 @@ def _validate_state_dict(self, state_dict: dict[str, torch.Tensor]): collected_keys = set(state_dict.keys()) missing_keys = hf_model_keys - collected_keys - assert len(missing_keys) == 0, f"Missing keys in collected state dict: {list(sorted(missing_keys))}" + assert ( + len(missing_keys) == 0 + ), f"Missing keys in collected state dict: {list(sorted(missing_keys))}" extra_keys = collected_keys - hf_model_keys - assert len(extra_keys) == 0, f"Extra keys in collected state dict: {list(sorted(extra_keys))}" + assert ( + len(extra_keys) == 0 + ), f"Extra keys in collected state dict: {list(sorted(extra_keys))}" for key in hf_model_keys: hf_shape = hf_state_dict[key].shape collected_shape = state_dict[key].shape - assert hf_shape == collected_shape, ( - f"Shape mismatch for key '{key}': original {hf_shape} vs collected {collected_shape}" - ) + assert ( + hf_shape == collected_shape + ), f"Shape mismatch for key '{key}': original {hf_shape} vs collected {collected_shape}" hf_dtype = hf_state_dict[key].dtype collected_dtype = state_dict[key].dtype - assert hf_dtype == collected_dtype, ( - f"Dtype mismatch for key '{key}': original {hf_dtype} vs collected {collected_dtype}" - ) + assert ( + hf_dtype == collected_dtype + ), f"Dtype mismatch for key '{key}': original {hf_dtype} vs collected {collected_dtype}" - torch.testing.assert_close(hf_state_dict[key], state_dict[key], atol=1e-6, rtol=1e-6) + torch.testing.assert_close( + hf_state_dict[key], 
state_dict[key], atol=1e-6, rtol=1e-6 + ) - print("FSDP checks passed: The merged state_dict matches the hf model saved by FSDPCheckpointManager.") + print( + "FSDP checks passed: The merged state_dict matches the hf model saved by FSDPCheckpointManager." + ) def cleanup(self): """Cleanup temporary files if needed.""" diff --git a/Agent0/executor_train/verl/verl/model_merger/megatron_model_merger.py b/Agent0/executor_train/verl/verl/model_merger/megatron_model_merger.py index c40bdf7..c94fd64 100644 --- a/Agent0/executor_train/verl/verl/model_merger/megatron_model_merger.py +++ b/Agent0/executor_train/verl/verl/model_merger/megatron_model_merger.py @@ -226,7 +226,11 @@ def _check_megatron_state_key(self, key: str) -> bool: ) def _split_tensors( - self, key: str, tensor: torch.Tensor, config: PretrainedConfig, is_value_model: bool = False + self, + key: str, + tensor: torch.Tensor, + config: PretrainedConfig, + is_value_model: bool = False, ) -> list[torch.Tensor]: """ Splits a tensor into multiple tensors based on the name. 
@@ -248,9 +252,9 @@ def _split_tensors( q_lst, k_lst, v_lst = [], [], [] assert config.num_attention_heads % config.num_key_value_heads == 0 num_q_per_kv = config.num_attention_heads // config.num_key_value_heads - assert tensor.shape[0] % (num_q_per_kv + 2) == 0, ( - f"Tensor shape {tensor.shape} is not divisible by {num_q_per_kv + 2}" - ) + assert ( + tensor.shape[0] % (num_q_per_kv + 2) == 0 + ), f"Tensor shape {tensor.shape} is not divisible by {num_q_per_kv + 2}" kv_size = tensor.shape[0] // (num_q_per_kv + 2) split_size = [kv_size * num_q_per_kv, kv_size, kv_size] @@ -266,11 +270,17 @@ def _split_tensors( k_lst.append(k) v_lst.append(v) - return [torch.cat(q_lst, dim=0), torch.cat(k_lst, dim=0), torch.cat(v_lst, dim=0)] + return [ + torch.cat(q_lst, dim=0), + torch.cat(k_lst, dim=0), + torch.cat(v_lst, dim=0), + ] else: return [tensor] - def _merge_state_dicts(self, model_state_dict_list: list[dict[str, Any]]) -> dict[str, torch.Tensor]: + def _merge_state_dicts( + self, model_state_dict_list: list[dict[str, Any]] + ) -> dict[str, torch.Tensor]: state_dict = {} layers_cum = 0 @@ -281,12 +291,16 @@ def _merge_state_dicts(self, model_state_dict_list: list[dict[str, Any]]) -> dic if "extra_state" in key: continue if self.config.tie_word_embedding and ("output_layer" in key): - print("skip lm_head and reward_head loading because of tie_word_embeddings") + print( + "skip lm_head and reward_head loading because of tie_word_embeddings" + ) continue self._check_megatron_state_key(key) hf_name = self._replace_name(key, self.params_mapping) - assert hf_name is not None, f"Failed to convert layer name [{key}] from megatron to huggingface." + assert ( + hf_name is not None + ), f"Failed to convert layer name [{key}] from megatron to huggingface." if "model.layers." 
in hf_name: local_layer_no = int(hf_name.split(".")[2]) layers_handled = max(local_layer_no, layers_handled) @@ -295,11 +309,17 @@ def _merge_state_dicts(self, model_state_dict_list: list[dict[str, Any]]) -> dic new_key_list[2] = str(global_layer_no) hf_name = ".".join(new_key_list) else: - warnings.warn(f"hf_name {hf_name} will not be fixed with layer number", stacklevel=2) + warnings.warn( + f"hf_name {hf_name} will not be fixed with layer number", + stacklevel=2, + ) tensor = model_state_dict[key] split_tensor = self._split_tensors( - key, tensor, self.hf_config, is_value_model=self.config.is_value_model + key, + tensor, + self.hf_config, + is_value_model=self.config.is_value_model, ) if len(split_tensor) == 1: @@ -313,7 +333,9 @@ def _merge_state_dicts(self, model_state_dict_list: list[dict[str, Any]]) -> dic state_dict[hf_name.replace("gate_up", "gate")] = split_tensor[0] state_dict[hf_name.replace("gate_up", "up")] = split_tensor[1] shape_info = ( - split_tensor.shape if isinstance(split_tensor, torch.Tensor) else [t.shape for t in split_tensor] + split_tensor.shape + if isinstance(split_tensor, torch.Tensor) + else [t.shape for t in split_tensor] ) print(f"converted {key} to {hf_name} with shape {shape_info}") @@ -361,7 +383,9 @@ def _validate_state_dict(self, state_dict: dict[str, torch.Tensor]): raise RuntimeError(f"key: {name} not exist in state_dict") param = ref_state_dict[name] assert loaded_weight.dtype == param.dtype - torch.testing.assert_close(loaded_weight.to("cpu"), param, atol=1e-2, rtol=5e-2) + torch.testing.assert_close( + loaded_weight.to("cpu"), param, atol=1e-2, rtol=5e-2 + ) def _replace_name(self, megatron_name: str, name_mapping: dict[str, str]) -> str: for m_name, v_name in name_mapping.items(): diff --git a/Agent0/executor_train/verl/verl/models/llama/megatron/checkpoint_utils/llama_loader.py b/Agent0/executor_train/verl/verl/models/llama/megatron/checkpoint_utils/llama_loader.py index dafecfd..b4557b0 100644 --- 
a/Agent0/executor_train/verl/verl/models/llama/megatron/checkpoint_utils/llama_loader.py +++ b/Agent0/executor_train/verl/verl/models/llama/megatron/checkpoint_utils/llama_loader.py @@ -41,7 +41,8 @@ def _megatron_calc_layer_map(config): for pp_rank_idx in range(pp_size): for virtual_pp_rank_idx in range(virtual_pp_size): layer_offset = ( - virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model + virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + + pp_rank_idx * num_layers_per_model ) for layer_idx in range(num_layers_per_model): layer_map[layer_offset + layer_idx] = ( @@ -53,7 +54,12 @@ def _megatron_calc_layer_map(config): def load_state_dict_to_megatron_llama( - state_dict, wrapped_models, config, params_dtype, is_value_model=False, tie_word_embeddings=False + state_dict, + wrapped_models, + config, + params_dtype, + is_value_model=False, + tie_word_embeddings=False, ): """Load merged state_dict to sharded Megatron module in training.""" from megatron.core import DistributedDataParallel as LocalDDP @@ -72,7 +78,9 @@ def _get_gpt_model(model): def fetch_params(module): for param in module.parameters(): torch.distributed.fetch( - param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group() + param.data, + src=mpu.get_data_parallel_src_rank(), + group=mpu.get_data_parallel_group(), ) dp_rank = mpu.get_data_parallel_rank() @@ -91,7 +99,9 @@ def fetch_params(module): assert len(wrapped_models) == virtual_pp_size num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size - assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers, ( + assert ( + num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers + ), ( f"num_layers_per_model: {num_layers_per_model} * pp_size: {pp_size} * virtual_pp_size " f"{virtual_pp_size} != config.num_hidden_layers: {config.num_hidden_layers}" ) @@ -109,7 +119,9 @@ def 
_fetch_tensor(tensor, name) -> torch.Tensor: if tensor is not None: tensor.data.copy_(state_dict[name]) - def _fetch_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor: + def _fetch_tp_shard_tensor_vocab( + tensor, name, chunk_dim=0, mutate_func=None + ) -> torch.Tensor: """fetch tensor in tp shards""" nonlocal state_dict tp_rank = mpu.get_tensor_model_parallel_rank() @@ -125,7 +137,9 @@ def _fetch_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None) -> else: print(f"tp_shard tensor:[{name}] not in state_dict, skip loading") - def _fetch_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor: + def _fetch_tp_shard_tensor( + tensor, name, chunk_dim=0, mutate_func=None + ) -> torch.Tensor: """fetch tensor in tp shards""" nonlocal state_dict tp_rank = mpu.get_tensor_model_parallel_rank() @@ -151,21 +165,30 @@ def _fetch_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tensor: gate_weight = state_dict[gate_name] up_weight = state_dict[up_name] new_gate_up_weight = torch.empty( - config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=get_device_id() + config.intermediate_size * 2, + config.hidden_size, + dtype=params_dtype, + device=get_device_id(), ) for i in range(tp_size): intermediate_size_tp = config.intermediate_size // tp_size - gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] - up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] - new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)].copy_( - torch.cat([gate_weight_tp, up_weight_tp], dim=0) - ) + gate_weight_tp = gate_weight[ + i * intermediate_size_tp : (i + 1) * intermediate_size_tp + ] + up_weight_tp = up_weight[ + i * intermediate_size_tp : (i + 1) * intermediate_size_tp + ] + new_gate_up_weight[ + intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1) + ].copy_(torch.cat([gate_weight_tp, 
up_weight_tp], dim=0)) tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0) if tensor is not None: tensor.data.copy_(tensor_chunk[tp_rank]) else: - print(f"tp_shard tensor:[{gate_name}, {up_name}] not in state_dict, skip loading") + print( + f"tp_shard tensor:[{gate_name}, {up_name}] not in state_dict, skip loading" + ) def _fetch_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name) -> torch.Tensor: """fetch tensor in tp shards across mp_group""" @@ -185,28 +208,42 @@ def _fetch_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name) -> torch.Tensor: kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size total_size = q_size_tp + 2 * kv_size_tp new_weight_qkv = torch.empty( - total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id() + total_size * tp_size, + config.hidden_size, + dtype=params_dtype, + device=get_device_id(), ) for i in range(tp_size): q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] k_part = full_weight_k[i * kv_size_tp : (i + 1) * kv_size_tp] v_part = full_weight_v[i * kv_size_tp : (i + 1) * kv_size_tp] - new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0)) + new_weight_qkv[i * total_size : (i + 1) * total_size].copy_( + torch.cat([q_part, k_part, v_part], dim=0) + ) else: q_size_tp = config.hidden_size // tp_size kv_size_tp = hidden_size_per_head total_size = q_size_tp + 2 * kv_size_tp new_weight_qkv = torch.empty( - total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id() + total_size * tp_size, + config.hidden_size, + dtype=params_dtype, + device=get_device_id(), ) for i in range(tp_size): q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] - start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head - end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head + start_idx = ( + i * config.num_key_value_heads // tp_size * hidden_size_per_head + ) + end_idx = ( 
+ i * config.num_key_value_heads // tp_size + 1 + ) * hidden_size_per_head k_part = full_weight_k[start_idx:end_idx] v_part = full_weight_v[start_idx:end_idx] - new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0)) + new_weight_qkv[i * total_size : (i + 1) * total_size].copy_( + torch.cat([q_part, k_part, v_part], dim=0) + ) tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0) if tensor is not None: @@ -235,9 +272,10 @@ def _fetch_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name) -> torch.Tensor: for vpp_rank in range(vpp_size): num_layer_vpp_chunk = num_layer_per_pp // vpp_size num_layer_this_model = num_layer_vpp_chunk - offset = vpp_rank * (config.num_hidden_layers // mpu.get_virtual_pipeline_model_parallel_world_size()) + ( - mpu.get_pipeline_model_parallel_rank() * num_layer_vpp_chunk - ) + offset = vpp_rank * ( + config.num_hidden_layers + // mpu.get_virtual_pipeline_model_parallel_world_size() + ) + (mpu.get_pipeline_model_parallel_rank() * num_layer_vpp_chunk) layer_list.extend(list(range(offset, offset + num_layer_this_model))) else: num_layer_this_model = num_layer_per_pp @@ -271,7 +309,11 @@ def _fetch_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name) -> torch.Tensor: ) _fetch_tensor( - sync_layer.post_attention_layernorm.weight if dst_pp_rank == pp_rank else None, + ( + sync_layer.post_attention_layernorm.weight + if dst_pp_rank == pp_rank + else None + ), f"{layer_name}.post_attention_layernorm.weight", ) @@ -300,10 +342,16 @@ def _fetch_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name) -> torch.Tensor: lm_head_weight = gpt_model_module.lm_head.weight if is_value_model: - if "lm_head.weight" in state_dict and state_dict["lm_head.weight"].shape[0] == 1: + if ( + "lm_head.weight" in state_dict + and state_dict["lm_head.weight"].shape[0] == 1 + ): _fetch_tensor(lm_head_weight, "lm_head.weight") print_rank_0("load lm_head weight") - elif "reward_head.weight" in state_dict and 
state_dict["reward_head.weight"].shape[0] == 1: + elif ( + "reward_head.weight" in state_dict + and state_dict["reward_head.weight"].shape[0] == 1 + ): _fetch_tensor(lm_head_weight, "reward_head.weight") print_rank_0("load lm_head from value_head weight") else: @@ -314,4 +362,6 @@ def _fetch_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name) -> torch.Tensor: dist.barrier() get_torch_device().empty_cache() - print_rank_0(f"loading megatron ckpt done, time elapsed {time.time() - start_time}s") + print_rank_0( + f"loading megatron ckpt done, time elapsed {time.time() - start_time}s" + ) diff --git a/Agent0/executor_train/verl/verl/models/llama/megatron/checkpoint_utils/llama_loader_depracated.py b/Agent0/executor_train/verl/verl/models/llama/megatron/checkpoint_utils/llama_loader_depracated.py index 2f65bc6..d5be6f9 100644 --- a/Agent0/executor_train/verl/verl/models/llama/megatron/checkpoint_utils/llama_loader_depracated.py +++ b/Agent0/executor_train/verl/verl/models/llama/megatron/checkpoint_utils/llama_loader_depracated.py @@ -41,7 +41,8 @@ def _megatron_calc_layer_map(config): for pp_rank_idx in range(pp_size): for virtual_pp_rank_idx in range(virtual_pp_size): layer_offset = ( - virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model + virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + + pp_rank_idx * num_layers_per_model ) for layer_idx in range(num_layers_per_model): layer_map[layer_offset + layer_idx] = ( @@ -53,7 +54,12 @@ def _megatron_calc_layer_map(config): def load_state_dict_to_megatron_llama( - state_dict, wrapped_models, config, params_dtype, is_value_model=False, tie_word_embeddings=False + state_dict, + wrapped_models, + config, + params_dtype, + is_value_model=False, + tie_word_embeddings=False, ): """Load merged state_dict to sharded Megatron module in training.""" from megatron.core import DistributedDataParallel as LocalDDP @@ -72,7 +78,9 @@ def _get_gpt_model(model): def 
broadcast_params(module): for param in module.parameters(): torch.distributed.broadcast( - param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group() + param.data, + src=mpu.get_data_parallel_src_rank(), + group=mpu.get_data_parallel_group(), ) dp_rank = mpu.get_data_parallel_rank() @@ -91,7 +99,9 @@ def broadcast_params(module): assert len(wrapped_models) == virtual_pp_size num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size - assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers, ( + assert ( + num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers + ), ( f"num_layers_per_model: {num_layers_per_model} * pp_size: {pp_size} * virtual_pp_size " f"{virtual_pp_size} != config.num_hidden_layers: {config.num_hidden_layers}" ) @@ -137,7 +147,9 @@ def _broadcast_tensor(tensor, name) -> torch.Tensor: tensor.data.copy_(weight) dist.broadcast(tensor, src=0, group=mp_group) - def _broadcast_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor: + def _broadcast_tp_shard_tensor_vocab( + tensor, name, chunk_dim=0, mutate_func=None + ) -> torch.Tensor: """broadcast tensor in tp shards across mp_group""" nonlocal state_dict nonlocal mp_group @@ -173,10 +185,12 @@ def _broadcast_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None requires_grad=False, ) else: - assert tensor.shape == chunk_shape, ( - f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" + assert ( + tensor.shape == chunk_shape + ), f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" + sync_tensor = torch.empty_like( + tensor, device=get_device_id(), requires_grad=False ) - sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) for i in range(tp_size): if torch.distributed.get_rank() == 0: @@ -185,7 +199,9 @@ def _broadcast_tp_shard_tensor_vocab(tensor, name, 
chunk_dim=0, mutate_func=None if (i == tp_rank) and (tensor is not None): tensor.data.copy_(sync_tensor) - def _broadcast_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor: + def _broadcast_tp_shard_tensor( + tensor, name, chunk_dim=0, mutate_func=None + ) -> torch.Tensor: """broadcast tensor in tp shards across mp_group""" nonlocal state_dict nonlocal mp_group @@ -220,10 +236,12 @@ def _broadcast_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> t requires_grad=False, ) else: - assert tensor.shape == chunk_shape, ( - f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" + assert ( + tensor.shape == chunk_shape + ), f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" + sync_tensor = torch.empty_like( + tensor, device=get_device_id(), requires_grad=False ) - sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) for i in range(tp_size): if torch.distributed.get_rank() == 0: @@ -243,15 +261,22 @@ def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tens gate_weight = state_dict[gate_name] up_weight = state_dict[up_name] new_gate_up_weight = torch.empty( - config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=get_device_id() + config.intermediate_size * 2, + config.hidden_size, + dtype=params_dtype, + device=get_device_id(), ) for i in range(tp_size): intermediate_size_tp = config.intermediate_size // tp_size - gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] - up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] - new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)].copy_( - torch.cat([gate_weight_tp, up_weight_tp], dim=0) - ) + gate_weight_tp = gate_weight[ + i * intermediate_size_tp : (i + 1) * intermediate_size_tp + ] + up_weight_tp = up_weight[ + i * intermediate_size_tp 
: (i + 1) * intermediate_size_tp + ] + new_gate_up_weight[ + intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1) + ].copy_(torch.cat([gate_weight_tp, up_weight_tp], dim=0)) tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0) chunk_shape = tensor_chunk[0].shape @@ -263,7 +288,9 @@ def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tens chunk_shape = obj_list[0] if chunk_shape is None: # all or none ranks in the mp_group should reach here - print_rank_0(f"tp_shard tensor:[{gate_name, up_name}] not in state_dict, skip loading") + print_rank_0( + f"tp_shard tensor:[{gate_name, up_name}] not in state_dict, skip loading" + ) return if tensor is None: @@ -278,7 +305,9 @@ def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tens f"rank #{torch.distributed.get_rank() == 0:} tensor {gate_name, up_name} shape " f"{tensor.shape} != {chunk_shape}" ) - sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) + sync_tensor = torch.empty_like( + tensor, device=get_device_id(), requires_grad=False + ) for i in range(tp_size): if torch.distributed.get_rank() == 0: @@ -295,7 +324,9 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name) -> torch.Tens tp_size = mpu.get_tensor_model_parallel_world_size() if torch.distributed.get_rank() == 0: - assert q_name in state_dict and k_name in state_dict and v_name in state_dict + assert ( + q_name in state_dict and k_name in state_dict and v_name in state_dict + ) full_weight_q = state_dict[q_name] full_weight_k = state_dict[k_name] full_weight_v = state_dict[v_name] @@ -304,10 +335,15 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name) -> torch.Tens if config.num_key_value_heads >= tp_size: q_size_tp = config.hidden_size // tp_size - kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size + kv_size_tp = ( + hidden_size_per_head * config.num_key_value_heads // tp_size + ) total_size = q_size_tp + 
2 * kv_size_tp new_weight_qkv = torch.empty( - total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id() + total_size * tp_size, + config.hidden_size, + dtype=params_dtype, + device=get_device_id(), ) for i in range(tp_size): q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] @@ -322,12 +358,19 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name) -> torch.Tens kv_size_tp = hidden_size_per_head total_size = q_size_tp + 2 * kv_size_tp new_weight_qkv = torch.empty( - total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id() + total_size * tp_size, + config.hidden_size, + dtype=params_dtype, + device=get_device_id(), ) for i in range(tp_size): q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] - start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head - end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head + start_idx = ( + i * config.num_key_value_heads // tp_size * hidden_size_per_head + ) + end_idx = ( + i * config.num_key_value_heads // tp_size + 1 + ) * hidden_size_per_head k_part = full_weight_k[start_idx:end_idx] v_part = full_weight_v[start_idx:end_idx] new_weight_qkv[i * total_size : (i + 1) * total_size].copy_( @@ -344,7 +387,9 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name) -> torch.Tens chunk_shape = obj_list[0] if chunk_shape is None: # all or none ranks in the mp_group should reach here - print_rank_0(f"tp_shard tensor:[{q_name, k_name, v_name}] not in state_dict, skip loading") + print_rank_0( + f"tp_shard tensor:[{q_name, k_name, v_name}] not in state_dict, skip loading" + ) return if tensor is None: @@ -355,10 +400,12 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name) -> torch.Tens requires_grad=False, ) else: - assert tensor.shape == chunk_shape, ( - f"rank #{torch.distributed.get_rank()} tensor {q_name} shape {tensor.shape} != {chunk_shape}" + assert ( + tensor.shape == 
chunk_shape + ), f"rank #{torch.distributed.get_rank()} tensor {q_name} shape {tensor.shape} != {chunk_shape}" + sync_tensor = torch.empty_like( + tensor, device=get_device_id(), requires_grad=False ) - sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) for i in range(tp_size): if torch.distributed.get_rank() == 0: @@ -375,7 +422,9 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name) -> torch.Tens embed_tokens_weight = None if pp_rank == 0: embed_tokens_weight = gpt_model_module.model.embed_tokens.weight - _broadcast_tp_shard_tensor_vocab(embed_tokens_weight, "model.embed_tokens.weight") + _broadcast_tp_shard_tensor_vocab( + embed_tokens_weight, "model.embed_tokens.weight" + ) # Transformer layers # ------------------- @@ -395,7 +444,11 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name) -> torch.Tens ) _broadcast_tp_shard_tensor_qkv( - sync_layer.self_attn.qkv_proj.weight if dst_pp_rank == pp_rank else None, + ( + sync_layer.self_attn.qkv_proj.weight + if dst_pp_rank == pp_rank + else None + ), f"{layer_name}.self_attn.q_proj.weight", f"{layer_name}.self_attn.k_proj.weight", f"{layer_name}.self_attn.v_proj.weight", @@ -408,7 +461,11 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name) -> torch.Tens ) _broadcast_tensor( - sync_layer.post_attention_layernorm.weight if dst_pp_rank == pp_rank else None, + ( + sync_layer.post_attention_layernorm.weight + if dst_pp_rank == pp_rank + else None + ), f"{layer_name}.post_attention_layernorm.weight", ) @@ -438,10 +495,16 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name) -> torch.Tens lm_head_weight = gpt_model_module.lm_head.weight if is_value_model: - if "lm_head.weight" in state_dict and state_dict["lm_head.weight"].shape[0] == 1: + if ( + "lm_head.weight" in state_dict + and state_dict["lm_head.weight"].shape[0] == 1 + ): _broadcast_tensor(lm_head_weight, "lm_head.weight") print_rank_0("load lm_head weight") - elif 
"reward_head.weight" in state_dict and state_dict["reward_head.weight"].shape[0] == 1: + elif ( + "reward_head.weight" in state_dict + and state_dict["reward_head.weight"].shape[0] == 1 + ): _broadcast_tensor(lm_head_weight, "reward_head.weight") print_rank_0("load lm_head from value_head weight") else: @@ -455,4 +518,6 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name) -> torch.Tens broadcast_params(wrapped_model) get_torch_device().empty_cache() - print_rank_0(f"loading megatron ckpt done, time elapsed {time.time() - start_time}s") + print_rank_0( + f"loading megatron ckpt done, time elapsed {time.time() - start_time}s" + ) diff --git a/Agent0/executor_train/verl/verl/models/llama/megatron/checkpoint_utils/llama_saver.py b/Agent0/executor_train/verl/verl/models/llama/megatron/checkpoint_utils/llama_saver.py index 595efcd..2da7855 100644 --- a/Agent0/executor_train/verl/verl/models/llama/megatron/checkpoint_utils/llama_saver.py +++ b/Agent0/executor_train/verl/verl/models/llama/megatron/checkpoint_utils/llama_saver.py @@ -32,9 +32,9 @@ def _megatron_calc_global_rank(tp_rank: int = 0, dp_rank: int = 0, pp_rank: int tp_size = mpu.get_tensor_model_parallel_world_size() dp_size = mpu.get_data_parallel_world_size() pp_size = mpu.get_pipeline_model_parallel_world_size() - assert tp_size * dp_size * pp_size == torch.distributed.get_world_size(), ( - f"{tp_size} x {dp_size} x {pp_size} != {torch.distributed.get_world_size()}" - ) + assert ( + tp_size * dp_size * pp_size == torch.distributed.get_world_size() + ), f"{tp_size} x {dp_size} x {pp_size} != {torch.distributed.get_world_size()}" # We only support TP-DP-PP grouping, for correctness when resharding return (pp_rank * dp_size + dp_rank) * tp_size + tp_rank @@ -58,7 +58,8 @@ def _megatron_calc_layer_map(config): for pp_rank_idx in range(pp_size): for virtual_pp_rank_idx in range(virtual_pp_size): layer_offset = ( - virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * 
num_layers_per_model + virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + + pp_rank_idx * num_layers_per_model ) for layer_idx in range(num_layers_per_model): layer_map[layer_offset + layer_idx] = ( @@ -69,7 +70,9 @@ def _megatron_calc_layer_map(config): return layer_map -def merge_megatron_ckpt_llama(wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False): +def merge_megatron_ckpt_llama( + wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False +): """Merge sharded parameters of a Megatron module into a merged checkpoint. Args: @@ -111,10 +114,10 @@ def _get_gpt_model(model): for i, wrapped_model in enumerate(wrapped_models): models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module)) - assert len(models[i].model.layers) == num_layers_per_model, ( - "len model layers {} not equal to num_layers_per_model {}".format( - len(models[i].model.layers), num_layers_per_model - ) + assert ( + len(models[i].model.layers) == num_layers_per_model + ), "len model layers {} not equal to num_layers_per_model {}".format( + len(models[i].model.layers), num_layers_per_model ) state_dict = dict() @@ -165,7 +168,9 @@ def _broadcast_tensor(tensor, name, src_pp_rank) -> torch.Tensor: if torch.distributed.get_rank() == 0: state_dict[name] = _get_cpu_tensor(weight) - def _broadcast_tp_shard_tensor(tensor, name, src_pp_rank, concat_dim=0, mutate_func=None) -> torch.Tensor: + def _broadcast_tp_shard_tensor( + tensor, name, src_pp_rank, concat_dim=0, mutate_func=None + ) -> torch.Tensor: """broadcast tensor in tp shards across mp_group""" nonlocal state_dict nonlocal mp_group @@ -192,8 +197,14 @@ def _broadcast_tp_shard_tensor(tensor, name, src_pp_rank, concat_dim=0, mutate_f chunk_tensors = [None] * tp_size for i in range(tp_size): - cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank) - sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor + 
cur_src_rank = _megatron_calc_global_rank( + tp_rank=i, dp_rank=0, pp_rank=src_pp_rank + ) + sync_tensor = ( + tensor + if torch.distributed.get_rank() == cur_src_rank + else buffer_tensor + ) dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group) if torch.distributed.get_rank() == 0: @@ -205,7 +216,9 @@ def _broadcast_tp_shard_tensor(tensor, name, src_pp_rank, concat_dim=0, mutate_f full_tensor = mutate_func(full_tensor) state_dict[name] = full_tensor - def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name, src_pp_rank) -> torch.Tensor: + def _broadcast_tp_shard_tensor_gate_up( + tensor, gate_name, up_name, src_pp_rank + ) -> torch.Tensor: """broadcast tensor in tp shards across mp_group""" nonlocal state_dict nonlocal mp_group @@ -219,7 +232,9 @@ def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name, src_pp_rank) chunk_shape = obj_list[0] if chunk_shape is None: # all or none ranks in the mp_group should reach here - print_rank_0(f"tp_shard tensor:[{gate_name, up_name}] not exist, skip collecting") + print_rank_0( + f"tp_shard tensor:[{gate_name, up_name}] not exist, skip collecting" + ) return buffer_tensor = torch.empty( @@ -232,8 +247,14 @@ def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name, src_pp_rank) chunk_tensors = [None] * tp_size for i in range(tp_size): - cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank) - sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor + cur_src_rank = _megatron_calc_global_rank( + tp_rank=i, dp_rank=0, pp_rank=src_pp_rank + ) + sync_tensor = ( + tensor + if torch.distributed.get_rank() == cur_src_rank + else buffer_tensor + ) dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group) if torch.distributed.get_rank() == 0: @@ -245,7 +266,9 @@ def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name, src_pp_rank) gate_weight_list = [] up_weight_list = [] for i in range(tp_size): - gate_up_weight_tp = 
full_tensor[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)] + gate_up_weight_tp = full_tensor[ + intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1) + ] gate_weight_tp = gate_up_weight_tp[:intermediate_size_tp] up_weight_tp = gate_up_weight_tp[intermediate_size_tp:] gate_weight_list.append(gate_weight_tp) @@ -281,8 +304,14 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, src_pp_rank): chunk_tensors = [None] * tp_size for i in range(tp_size): - cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank) - sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor + cur_src_rank = _megatron_calc_global_rank( + tp_rank=i, dp_rank=0, pp_rank=src_pp_rank + ) + sync_tensor = ( + tensor + if torch.distributed.get_rank() == cur_src_rank + else buffer_tensor + ) dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group) if torch.distributed.get_rank() == 0: @@ -297,7 +326,9 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, src_pp_rank): if config.num_key_value_heads >= tp_size: q_size_tp = config.hidden_size // tp_size - kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size + kv_size_tp = ( + hidden_size_per_head * config.num_key_value_heads // tp_size + ) total_size = q_size_tp + 2 * kv_size_tp for i in range(tp_size): qkv_part = full_tensor[i * total_size : (i + 1) * total_size] @@ -406,23 +437,32 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, src_pp_rank): if is_value_model: if pp_rank == pp_size - 1: - print(f"gpt_model_module.lm_head.weight: {gpt_model_module.lm_head.weight.shape}") + print( + f"gpt_model_module.lm_head.weight: {gpt_model_module.lm_head.weight.shape}" + ) _broadcast_tensor( gpt_model_module.lm_head.weight if pp_rank == pp_size - 1 else None, "lm_head.weight", src_pp_rank=pp_size - 1, ) _broadcast_tensor( - gpt_model_module.reward_head.weight - if pp_rank == pp_size - 1 and 
getattr(gpt_model_module, "reward_weight", None) is not None - else None, + ( + gpt_model_module.reward_head.weight + if pp_rank == pp_size - 1 + and getattr(gpt_model_module, "reward_weight", None) is not None + else None + ), "reward_head.weight", src_pp_rank=pp_size - 1, ) else: _broadcast_tp_shard_tensor( - getattr(gpt_model_module.lm_head, "weight", None) if pp_rank == pp_size - 1 else None, + ( + getattr(gpt_model_module.lm_head, "weight", None) + if pp_rank == pp_size - 1 + else None + ), "lm_head.weight", src_pp_rank=pp_size - 1, ) diff --git a/Agent0/executor_train/verl/verl/models/llama/megatron/layers/parallel_attention.py b/Agent0/executor_train/verl/verl/models/llama/megatron/layers/parallel_attention.py index e8aacbd..96129da 100644 --- a/Agent0/executor_train/verl/verl/models/llama/megatron/layers/parallel_attention.py +++ b/Agent0/executor_train/verl/verl/models/llama/megatron/layers/parallel_attention.py @@ -42,17 +42,23 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + inv_freq = 1.0 / ( + self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) + ) self.register_buffer("inv_freq", inv_freq, persistent=False) # Build here to make `torch.jit.trace` work. 
self._set_cos_sin_cache( - seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + seq_len=max_position_embeddings, + device=self.inv_freq.device, + dtype=torch.get_default_dtype(), ) def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype + ) freqs = torch.einsum("i,j->ij", t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation @@ -74,13 +80,22 @@ def forward(self, x, seq_len=None): class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + ): self.scaling_factor = scaling_factor super().__init__(dim, max_position_embeddings, base, device) def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype + ) t = t / self.scaling_factor freqs = torch.einsum("i,j->ij", t, self.inv_freq) @@ -93,7 +108,14 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): """LlamaRotaryEmbedding extended with Dynamic NTK scaling. 
Credits to the Reddit users /u/bloc97 and /u/emozilla""" - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + ): self.scaling_factor = scaling_factor super().__init__(dim, max_position_embeddings, base, device) @@ -102,12 +124,17 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): if seq_len > self.max_position_embeddings: base = self.base * ( - (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) + (self.scaling_factor * seq_len / self.max_position_embeddings) + - (self.scaling_factor - 1) ) ** (self.dim / (self.dim - 2)) - inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + inv_freq = 1.0 / ( + base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) + ) self.register_buffer("inv_freq", inv_freq, persistent=False) - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype + ) freqs = torch.einsum("i,j->ij", t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation @@ -117,12 +144,20 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): class LlamaLlama3ScalingRotaryEmbedding(LlamaRotaryEmbedding): - def __init__(self, dim, config, max_position_embeddings=2048, base=10000, device=None): + def __init__( + self, dim, config, max_position_embeddings=2048, base=10000, device=None + ): super().__init__(dim, max_position_embeddings, base, device) - self.factor = config.rope_scaling["factor"] # `8` in the original implementation - self.high_freq_factor = config.rope_scaling["high_freq_factor"] # `1` in the original implementation - self.low_freq_factor = config.rope_scaling["low_freq_factor"] # `4` in the original implementation + self.factor 
= config.rope_scaling[ + "factor" + ] # `8` in the original implementation + self.high_freq_factor = config.rope_scaling[ + "high_freq_factor" + ] # `1` in the original implementation + self.low_freq_factor = config.rope_scaling[ + "low_freq_factor" + ] # `4` in the original implementation self.old_context_len = config.rope_scaling[ "original_max_position_embeddings" ] # `8192` in the original implementation @@ -132,12 +167,16 @@ def __init__(self, dim, config, max_position_embeddings=2048, base=10000, device wavelen = 2 * math.pi / self.inv_freq # wavelen < high_freq_wavelen: do nothing; wavelen > low_freq_wavelen: divide by factor - inv_freq_llama = torch.where(wavelen > low_freq_wavelen, self.inv_freq / self.factor, self.inv_freq) + inv_freq_llama = torch.where( + wavelen > low_freq_wavelen, self.inv_freq / self.factor, self.inv_freq + ) # otherwise: interpolate between the two, using a smooth factor smooth_factor = (self.old_context_len / wavelen - self.low_freq_factor) / ( self.high_freq_factor - self.low_freq_factor ) - smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / self.factor + smooth_factor * inv_freq_llama + smoothed_inv_freq = ( + 1 - smooth_factor + ) * inv_freq_llama / self.factor + smooth_factor * inv_freq_llama is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen) inv_freq = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama) @@ -145,7 +184,9 @@ def __init__(self, dim, config, max_position_embeddings=2048, base=10000, device # Build here to make `torch.jit.trace` work. 
self._set_cos_sin_cache( - seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + seq_len=max_position_embeddings, + device=self.inv_freq.device, + dtype=torch.get_default_dtype(), ) @@ -172,7 +213,9 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) @@ -193,9 +236,9 @@ def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig): # assign values after tp tp_size = mpu.get_tensor_model_parallel_world_size() - assert self.num_heads % tp_size == 0, ( - f"num_head must be divisible by tp_size. Got num_head={self.num_heads}, tp_size={tp_size}" - ) + assert ( + self.num_heads % tp_size == 0 + ), f"num_head must be divisible by tp_size. Got num_head={self.num_heads}, tp_size={tp_size}" assert self.num_key_value_heads % tp_size == 0, ( f"num_key_value_heads must be divisible by tp_size. 
Got num_key_value_heads=" f"{self.num_key_value_heads}, tp_size={tp_size}" @@ -255,7 +298,9 @@ def _init_rope(self): base=self.rope_theta, ) else: - rope_type_key = "type" if "type" in self.config.rope_scaling else "rope_type" + rope_type_key = ( + "type" if "type" in self.config.rope_scaling else "rope_type" + ) scaling_type = self.config.rope_scaling[rope_type_key] scaling_factor = self.config.rope_scaling["factor"] if scaling_type == "linear": @@ -283,7 +328,11 @@ def _init_rope(self): raise ValueError(f"Unknown RoPE scaling type {scaling_type}") def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + return ( + tensor.view(bsz, seq_len, self.num_heads, self.head_dim) + .transpose(1, 2) + .contiguous() + ) def forward( self, @@ -293,20 +342,32 @@ def forward( ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() qkv = self.qkv_proj(hidden_states)[0] - query_states, key_states, value_states = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1) + query_states, key_states, value_states = qkv.split( + [self.q_size, self.k_size, self.v_size], dim=-1 + ) - query_states = query_states.view(bsz, q_len, self.num_heads_per_tp, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads_per_tp, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads_per_tp, self.head_dim).transpose(1, 2) + query_states = query_states.view( + bsz, q_len, self.num_heads_per_tp, self.head_dim + ).transpose(1, 2) + key_states = key_states.view( + bsz, q_len, self.num_key_value_heads_per_tp, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + bsz, q_len, self.num_key_value_heads_per_tp, self.head_dim + ).transpose(1, 2) kv_seq_len = key_states.shape[-2] cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - 
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids + ) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + attn_weights = torch.matmul( + query_states, key_states.transpose(2, 3) + ) / math.sqrt(self.head_dim) if attn_weights.size() != (bsz, self.num_heads_per_tp, q_len, kv_seq_len): raise ValueError( @@ -322,7 +383,9 @@ def forward( attn_weights = attn_weights + attention_mask # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(query_states.dtype) attn_output = torch.matmul(attn_weights, value_states) if attn_output.size() != (bsz, self.num_heads_per_tp, q_len, self.head_dim): @@ -352,7 +415,9 @@ def forward( def apply_rotary_pos_emb_rmpad(q, k, cos, sin, position_ids, indices, sequence_length): batch_size = position_ids.shape[0] - q = pad_input(q, indices, batch_size, sequence_length) # (batch_size, seqlen, num_head, head_dim) + q = pad_input( + q, indices, batch_size, sequence_length + ) # (batch_size, seqlen, num_head, head_dim) k = pad_input(k, indices, batch_size, sequence_length) cos = cos[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim] sin = sin[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim] @@ -369,10 +434,22 @@ def apply_rotary_pos_emb_rmpad(q, k, cos, sin, position_ids, indices, sequence_l # cos/sin shoudl be: (seq_length, rotary_dim / 2) def apply_rotary_pos_emb_rmpad_flash(q, k, cos, sin, cu_seqlens, max_seqlen): q_embed = apply_rotary_emb( - q, cos, sin, interleaved=False, inplace=False, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen + q, + cos, + sin, + 
interleaved=False, + inplace=False, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, ) k_embed = apply_rotary_emb( - k, cos, sin, interleaved=False, inplace=False, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen + k, + cos, + sin, + interleaved=False, + inplace=False, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, ) return q_embed, k_embed @@ -387,7 +464,9 @@ def forward( cu_seqlens: torch.Tensor = None, max_seqlen_in_batch: int = None, ): - total_nnz, _, _ = hidden_states.size() # This is the total_nnz padded after sequence parallel + total_nnz, _, _ = ( + hidden_states.size() + ) # This is the total_nnz padded after sequence parallel if self.megatron_config.sequence_parallel: total_nnz = total_nnz * mpu.get_tensor_model_parallel_world_size() @@ -407,14 +486,28 @@ def forward( # Flash attention requires the input to have the shape # batch_size x seq_length x head_dime x hidden_dim # therefore we just need to keep the original shape - query_states = query_states.view(total_nnz, self.num_heads_per_tp, self.head_dim) - key_states = key_states.view(total_nnz, self.num_key_value_heads_per_tp, self.head_dim) - value_states = value_states.view(total_nnz, self.num_key_value_heads_per_tp, self.head_dim) + query_states = query_states.view( + total_nnz, self.num_heads_per_tp, self.head_dim + ) + key_states = key_states.view( + total_nnz, self.num_key_value_heads_per_tp, self.head_dim + ) + value_states = value_states.view( + total_nnz, self.num_key_value_heads_per_tp, self.head_dim + ) cos, sin = self.rotary_emb(value_states, seq_len=sequence_length) - cos, sin = cos[:, : cos.shape[1] // 2], sin[:, : sin.shape[1] // 2] # flash attn only needs half + cos, sin = ( + cos[:, : cos.shape[1] // 2], + sin[:, : sin.shape[1] // 2], + ) # flash attn only needs half query_states, key_states = apply_rotary_pos_emb_rmpad_flash( - query_states, key_states, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen_in_batch + query_states, + key_states, + cos, + sin, + cu_seqlens=cu_seqlens, 
+ max_seqlen=max_seqlen_in_batch, ) # query_states, key_states = apply_rotary_pos_emb_rmpad(query_states, key_states, cos, sin, # position_ids, indices, @@ -449,12 +542,16 @@ def forward( ) attn_output_unpad = attn_output_unpad.to(input_dtype) - attn_output_unpad = attn_output_unpad.reshape(total_nnz, 1, self.hidden_size_per_tp).contiguous() + attn_output_unpad = attn_output_unpad.reshape( + total_nnz, 1, self.hidden_size_per_tp + ).contiguous() # sequence parallel reduce_scatter is performed inside RowColumnParallel if enabled # Here we need to repad if self.megatron_config.sequence_parallel: - attn_output_unpad = F.pad(attn_output_unpad, pad=(0, 0, 0, 0, 0, sequence_parallel_pad)) + attn_output_unpad = F.pad( + attn_output_unpad, pad=(0, 0, 0, 0, 0, sequence_parallel_pad) + ) attn_output_unpad = self.o_proj(attn_output_unpad)[0] return attn_output_unpad diff --git a/Agent0/executor_train/verl/verl/models/llama/megatron/layers/parallel_decoder.py b/Agent0/executor_train/verl/verl/models/llama/megatron/layers/parallel_decoder.py index f46e945..6253605 100644 --- a/Agent0/executor_train/verl/verl/models/llama/megatron/layers/parallel_decoder.py +++ b/Agent0/executor_train/verl/verl/models/llama/megatron/layers/parallel_decoder.py @@ -33,12 +33,16 @@ class ParallelLlamaDecoderLayer(nn.Module): - def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig, layer_idx: int): + def __init__( + self, config: LlamaConfig, megatron_config: ModelParallelConfig, layer_idx: int + ): super().__init__() self.config: TransformerConfig = convert_config(config, megatron_config) self.layer_idx = layer_idx self.hidden_size = config.hidden_size - self.self_attn = ParallelLlamaAttention(config=config, megatron_config=megatron_config) + self.self_attn = ParallelLlamaAttention( + config=config, megatron_config=megatron_config + ) self.mlp = ParallelLlamaMLP(config, megatron_config=megatron_config) self.input_layernorm = ParallelLlamaRMSNorm(config, megatron_config) @@ 
-49,7 +53,9 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + ) -> tuple[ + torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]] + ]: """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` @@ -100,12 +106,16 @@ def forward( class ParallelLlamaDecoderLayerRmPad(nn.Module): - def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig, layer_idx: int): + def __init__( + self, config: LlamaConfig, megatron_config: ModelParallelConfig, layer_idx: int + ): super().__init__() self.config: TransformerConfig = convert_config(config, megatron_config) self.layer_idx = layer_idx self.hidden_size = config.hidden_size - self.self_attn = ParallelLlamaAttentionRmPad(config=config, megatron_config=megatron_config) + self.self_attn = ParallelLlamaAttentionRmPad( + config=config, megatron_config=megatron_config + ) self.mlp = ParallelLlamaMLP(config, megatron_config=megatron_config) self.input_layernorm = ParallelLlamaRMSNorm(config, megatron_config) @@ -119,7 +129,9 @@ def forward( indices: torch.Tensor = None, cu_seqlens: int = None, max_seqlen_in_batch: int = None, - ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + ) -> tuple[ + torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]] + ]: residual = hidden_states # (total_nnz // sp, 1, hidden_size) hidden_states = self.input_layernorm(hidden_states) diff --git a/Agent0/executor_train/verl/verl/models/llama/megatron/layers/parallel_linear.py b/Agent0/executor_train/verl/verl/models/llama/megatron/layers/parallel_linear.py index 043726c..c2294ae 100644 --- a/Agent0/executor_train/verl/verl/models/llama/megatron/layers/parallel_linear.py +++ 
b/Agent0/executor_train/verl/verl/models/llama/megatron/layers/parallel_linear.py @@ -102,5 +102,7 @@ def forward( logits = super().forward(input_) logits = logits.float() if self.sequence_parallel: - logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False) + logits = tensor_parallel.gather_from_sequence_parallel_region( + logits, tensor_parallel_output_grad=False + ) return logits, None diff --git a/Agent0/executor_train/verl/verl/models/llama/megatron/modeling_llama_megatron.py b/Agent0/executor_train/verl/verl/models/llama/megatron/modeling_llama_megatron.py index ed5022e..16aec1f 100644 --- a/Agent0/executor_train/verl/verl/models/llama/megatron/modeling_llama_megatron.py +++ b/Agent0/executor_train/verl/verl/models/llama/megatron/modeling_llama_megatron.py @@ -33,7 +33,11 @@ from verl.utils.megatron import tensor_parallel as tp_utils from verl.utils.megatron_utils import TransformerConfig, convert_config -from .layers import ParallelLlamaDecoderLayer, ParallelLlamaDecoderLayerRmPad, ParallelLlamaRMSNorm +from .layers import ( + ParallelLlamaDecoderLayer, + ParallelLlamaDecoderLayerRmPad, + ParallelLlamaRMSNorm, +) """ TODO: @@ -44,7 +48,9 @@ # Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device): +def _make_causal_mask( + input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device +): """ Make causal mask used for bi-directional self-attention. 
""" @@ -68,7 +74,9 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] inverted_mask = 1.0 - expanded_mask - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + return inverted_mask.masked_fill( + inverted_mask.to(torch.bool), torch.finfo(dtype).min + ) class ParallelLlamaModel(nn.Module): @@ -86,19 +94,28 @@ def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig): self.vocab_size = config.vocab_size embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding() if megatron_config is not None: - assert embedding_kwargs.get("config", False), "must have ModelParallelConfig" + assert embedding_kwargs.get( + "config", False + ), "must have ModelParallelConfig" tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config) self.embed_tokens = tensor_parallel.VocabParallelEmbedding( - num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs + num_embeddings=config.vocab_size, + embedding_dim=config.hidden_size, + **embedding_kwargs, ) self.layers = nn.ModuleList( - [ParallelLlamaDecoderLayer(config, megatron_config) for _ in range(config.num_hidden_layers)] + [ + ParallelLlamaDecoderLayer(config, megatron_config) + for _ in range(config.num_hidden_layers) + ] ) self.norm = ParallelLlamaRMSNorm(config, megatron_config) # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask - def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds): + def _prepare_decoder_attention_mask( + self, attention_mask, input_shape, inputs_embeds + ): # create causal mask # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] combined_attention_mask = None @@ -111,11 +128,13 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_em if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = 
_expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( - inputs_embeds.device - ) + expanded_attn_mask = _expand_mask( + attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ).to(inputs_embeds.device) combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + expanded_attn_mask + if combined_attention_mask is None + else expanded_attn_mask + combined_attention_mask ) return combined_attention_mask @@ -140,7 +159,9 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) # embed positions - attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), inputs_embeds) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds + ) hidden_states = inputs_embeds @@ -236,14 +257,21 @@ def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig): embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding() self.megatron_config = megatron_config if megatron_config is not None: - assert embedding_kwargs.get("config", False), "must have ModelParallelConfig" + assert embedding_kwargs.get( + "config", False + ), "must have ModelParallelConfig" tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config) self.embed_tokens = tensor_parallel.VocabParallelEmbedding( - num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs + num_embeddings=config.vocab_size, + embedding_dim=config.hidden_size, + **embedding_kwargs, ) self.layers = nn.ModuleList( - [ParallelLlamaDecoderLayerRmPad(config, megatron_config) for _ in range(config.num_hidden_layers)] + [ + ParallelLlamaDecoderLayerRmPad(config, megatron_config) + for _ in range(config.num_hidden_layers) + ] ) self.norm = ParallelLlamaRMSNorm(config, megatron_config) @@ -265,12 +293,16 @@ def forward( Returns: """ - inputs_embeds = self.embed_tokens(input_ids) # (1, 
total_nnz) -> (1, total_nnz, hidden_size) + inputs_embeds = self.embed_tokens( + input_ids + ) # (1, total_nnz) -> (1, total_nnz, hidden_size) # (1, total_nnz, hidden_size) -> (total_nnz, 1, hidden_size) -> (total_nnz // sp, 1, hidden_size) inputs_embeds = inputs_embeds.transpose(0, 1) if self.megatron_config.sequence_parallel: - inputs_embeds = tensor_parallel.scatter_to_sequence_parallel_region(inputs_embeds) + inputs_embeds = tensor_parallel.scatter_to_sequence_parallel_region( + inputs_embeds + ) hidden_states = inputs_embeds for idx, decoder_layer in enumerate(self.layers): @@ -317,7 +349,9 @@ def _forward_head(self, hidden_states): # all_gather from sequence parallel region is performed inside lm_head logits = self.lm_head(hidden_states)[0] logits = logits.float() # (total_nnz_padded, 1, vocab_size // tp) - logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits) # (total_nnz_padded, 1, vocab_size) + logits = tensor_parallel.gather_from_tensor_model_parallel_region( + logits + ) # (total_nnz_padded, 1, vocab_size) return logits def forward( @@ -388,7 +422,9 @@ def _init_head(self, config): if self.megatron_config is not None: assert column_kwargs.get("config", False), "must have ModelParallelConfig" tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) - self.lm_head = nn.Linear(in_features=config.hidden_size, out_features=1, bias=False) + self.lm_head = nn.Linear( + in_features=config.hidden_size, out_features=1, bias=False + ) # lm_head is effectively the same as sequence parallel sp_utils.mark_parameter_as_sequence_parallel(self.lm_head.weight) @@ -396,7 +432,9 @@ def _forward_head(self, hidden_states): logits = self.lm_head(hidden_states) # (total_nnz_padded // tp, 1, 1) logits = logits.float() if self.megatron_config.sequence_parallel: - logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False) + logits = tensor_parallel.gather_from_sequence_parallel_region( + logits, 
tensor_parallel_output_grad=False + ) return logits def forward( @@ -425,7 +463,13 @@ class ParallelLlamaModelRmPadPP(nn.Module): config: LlamaConfig """ - def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig, pre_process, post_process): + def __init__( + self, + config: LlamaConfig, + megatron_config: ModelParallelConfig, + pre_process, + post_process, + ): super().__init__() self.config: TransformerConfig = convert_config(config, megatron_config) self.padding_idx = config.pad_token_id @@ -435,11 +479,15 @@ def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig, pr self.megatron_config = megatron_config embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding() if megatron_config is not None: - assert embedding_kwargs.get("config", False), "must have ModelParallelConfig" + assert embedding_kwargs.get( + "config", False + ), "must have ModelParallelConfig" tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config) if pre_process: self.embed_tokens = tensor_parallel.VocabParallelEmbedding( - num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs + num_embeddings=config.vocab_size, + embedding_dim=config.hidden_size, + **embedding_kwargs, ) else: self.embed_tokens = None @@ -454,14 +502,18 @@ def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig, pr self.layers = nn.ModuleList() self.num_layer_vpp_chunk = self.num_layer_per_pp // vpp_size self.num_layer_this_model = self.num_layer_vpp_chunk - offset = vpp_rank * (config.num_hidden_layers // vpp_size) + (pp_rank * self.num_layer_vpp_chunk) + offset = vpp_rank * (config.num_hidden_layers // vpp_size) + ( + pp_rank * self.num_layer_vpp_chunk + ) else: self.num_layer_this_model = self.num_layer_per_pp offset = pp_rank * self.num_layer_per_pp self.layers = nn.ModuleList() for i in range(self.num_layer_this_model): - layer = ParallelLlamaDecoderLayerRmPad(config, megatron_config, 
layer_idx=offset + i) + layer = ParallelLlamaDecoderLayerRmPad( + config, megatron_config, layer_idx=offset + i + ) self.layers.add_module(f"{i}", layer) if post_process: @@ -498,14 +550,18 @@ def forward( """ if self.pre_process: - inputs_embeds = self.embed_tokens(input_ids) # (1, total_nnz) -> (1, total_nnz, hidden_size) + inputs_embeds = self.embed_tokens( + input_ids + ) # (1, total_nnz) -> (1, total_nnz, hidden_size) # vocab parallel embedding will not do sequence parallel reduce-scatter in open source megatron # so need to deal with it by handle here: # (1, total_nnz, hidden_size) -> (total_nnz, 1, hidden_size) -> (total_nnz // sp, 1, hidden_size) inputs_embeds = inputs_embeds.transpose(0, 1) if self.megatron_config.sequence_parallel: - inputs_embeds = tensor_parallel.scatter_to_sequence_parallel_region(inputs_embeds) + inputs_embeds = tensor_parallel.scatter_to_sequence_parallel_region( + inputs_embeds + ) hidden_states = inputs_embeds else: @@ -543,11 +599,14 @@ def __init__( self.config: TransformerConfig = convert_config(config, megatron_config) self.megatron_config = megatron_config self.model = ParallelLlamaModelRmPadPP( - config, megatron_config=megatron_config, pre_process=pre_process, post_process=post_process - ) - assert share_embeddings_and_output_weights is False, ( - "Llama Model not supports sharing embedding and output weights" + config, + megatron_config=megatron_config, + pre_process=pre_process, + post_process=post_process, ) + assert ( + share_embeddings_and_output_weights is False + ), "Llama Model not supports sharing embedding and output weights" self.share_embeddings_and_output_weights = share_embeddings_and_output_weights self.vocab_size = config.vocab_size self.pre_process = pre_process @@ -634,7 +693,9 @@ def forward( hidden_states = outputs # print(f'hidden_states.shape = {hidden_states.shape}') # torch.Size([4, 32, 4096]) logits = self._forward_head(hidden_states) - logits = torch.squeeze(logits, dim=1) # remove the artificial 
batch dimension # torch.Size([8, 32, 16]) + logits = torch.squeeze( + logits, dim=1 + ) # remove the artificial batch dimension # torch.Size([8, 32, 16]) # remove padding from sequence parallel if self.megatron_config.sequence_parallel: @@ -662,7 +723,9 @@ def _init_head(self, config): if self.megatron_config is not None: assert column_kwargs.get("config", False), "must have ModelParallelConfig" tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) - self.lm_head = nn.Linear(in_features=config.hidden_size, out_features=1, bias=False) + self.lm_head = nn.Linear( + in_features=config.hidden_size, out_features=1, bias=False + ) # lm_head is effectively the same as sequence parallel sp_utils.mark_parameter_as_sequence_parallel(self.lm_head.weight) @@ -670,7 +733,9 @@ def _forward_head(self, hidden_states): logits = self.lm_head(hidden_states) # (total_nnz_padded // tp, 1, 1) logits = logits.float() if self.megatron_config.sequence_parallel: - logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False) + logits = tensor_parallel.gather_from_sequence_parallel_region( + logits, tensor_parallel_output_grad=False + ) return logits def forward( @@ -680,7 +745,11 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, ) -> tuple | CausalLMOutputWithPast: - output = super().forward(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids) + output = super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + ) if self.post_process: output.logits = torch.squeeze(output.logits, dim=-1) return output diff --git a/Agent0/executor_train/verl/verl/models/mcore/config_converter.py b/Agent0/executor_train/verl/verl/models/mcore/config_converter.py index 597afcd..58f72b7 100644 --- a/Agent0/executor_train/verl/verl/models/mcore/config_converter.py +++ 
b/Agent0/executor_train/verl/verl/models/mcore/config_converter.py @@ -25,7 +25,9 @@ def _get_base_transformer_config( - hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs + hf_config: PretrainedConfig, + dtype: torch.dtype, + **override_transformer_config_kwargs, ) -> dict: """ Create a base TransformerConfig with common parameters across different model architectures. @@ -92,7 +94,10 @@ def _get_base_transformer_config( def _get_mla_transformer_config( - hf_config: PretrainedConfig, mla_rope_config: dict, dtype: torch.dtype, **override_transformer_config_kwargs + hf_config: PretrainedConfig, + mla_rope_config: dict, + dtype: torch.dtype, + **override_transformer_config_kwargs, ) -> dict: """ Create a MLATransformerConfig with common parameters across different model architectures. @@ -107,7 +112,9 @@ def _get_mla_transformer_config( Returns: MLATransformerConfig with common parameters """ - base_config = _get_base_transformer_config(hf_config=hf_config, dtype=dtype, **override_transformer_config_kwargs) + base_config = _get_base_transformer_config( + hf_config=hf_config, dtype=dtype, **override_transformer_config_kwargs + ) mla_config = { # MLA specific parameters "q_lora_rank": hf_config.q_lora_rank, @@ -130,10 +137,16 @@ def _get_mla_transformer_config( def hf_to_mcore_config_dense( - hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs + hf_config: PretrainedConfig, + dtype: torch.dtype, + **override_transformer_config_kwargs, ) -> TransformerConfig: # for LlamaForCausalLM or Qwen2ForCausalLM - qkv_bias = True if "Qwen2ForCausalLM" in hf_config.architectures else getattr(hf_config, "attention_bias", False) + qkv_bias = ( + True + if "Qwen2ForCausalLM" in hf_config.architectures + else getattr(hf_config, "attention_bias", False) + ) qk_layernorm = True if "Qwen3ForCausalLM" in hf_config.architectures else False args: dict = _get_base_transformer_config( @@ -151,7 +164,9 @@ def 
hf_to_mcore_config_dense( def hf_to_mcore_config_qwen2moe( - hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs + hf_config: PretrainedConfig, + dtype: torch.dtype, + **override_transformer_config_kwargs, ) -> TransformerConfig: args: dict = _get_base_transformer_config( hf_config=hf_config, @@ -186,7 +201,9 @@ def hf_to_mcore_config_qwen2moe( def hf_to_mcore_config_mixtral( - hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs + hf_config: PretrainedConfig, + dtype: torch.dtype, + **override_transformer_config_kwargs, ) -> TransformerConfig: args: dict = _get_base_transformer_config( hf_config=hf_config, @@ -220,7 +237,9 @@ def hf_to_mcore_config_mixtral( def hf_to_mcore_config_qwen3moe( - hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs + hf_config: PretrainedConfig, + dtype: torch.dtype, + **override_transformer_config_kwargs, ) -> TransformerConfig: args: dict = _get_base_transformer_config( hf_config=hf_config, @@ -253,7 +272,9 @@ def hf_to_mcore_config_qwen3moe( def hf_to_mcore_config_dpskv3( - hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs + hf_config: PretrainedConfig, + dtype: torch.dtype, + **override_transformer_config_kwargs, ) -> MLATransformerConfig: # DeepseekV3ForCausalLM from megatron.core.transformer.enums import AttnBackend @@ -279,12 +300,12 @@ def hf_to_mcore_config_dpskv3( # disable MTP and quantization for now if "num_nextn_predict_layers" in hf_config: - assert hf_config.num_nextn_predict_layers == 0, ( - "MTP is not supported for now, please modify the config.json to set num_nextn_predict_layers to 0" - ) - assert "quantization_config" not in hf_config or not hf_config.quantization_config, ( - "quantization is not supported for now, please modify the config.json to remove quantization_config" - ) + assert ( + hf_config.num_nextn_predict_layers == 0 + ), "MTP is not supported for now, please modify 
the config.json to set num_nextn_predict_layers to 0" + assert ( + "quantization_config" not in hf_config or not hf_config.quantization_config + ), "quantization is not supported for now, please modify the config.json to remove quantization_config" args: dict = _get_mla_transformer_config( hf_config=hf_config, @@ -302,7 +323,8 @@ def hf_to_mcore_config_dpskv3( moe_router_enable_expert_bias=True, moe_router_topk=hf_config.num_experts_per_tok, num_moe_experts=hf_config.n_routed_experts, - moe_shared_expert_intermediate_size=hf_config.moe_intermediate_size * hf_config.n_shared_experts, + moe_shared_expert_intermediate_size=hf_config.moe_intermediate_size + * hf_config.n_shared_experts, moe_aux_loss_coeff=getattr(hf_config, "aux_loss_alpha", 0.001), moe_router_load_balancing_type="seq_aux_loss", moe_shared_expert_overlap=True, @@ -335,7 +357,9 @@ def hf_to_mcore_config_dpskv3( def hf_to_mcore_config_qwen2_5_vl( - hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs + hf_config: PretrainedConfig, + dtype: torch.dtype, + **override_transformer_config_kwargs, ) -> TransformerConfig: # Qwen2_5_VLForConditionalGeneration @@ -354,7 +378,9 @@ def hf_to_mcore_config_qwen2_5_vl( def hf_to_mcore_config_llama4( - hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs + hf_config: PretrainedConfig, + dtype: torch.dtype, + **override_transformer_config_kwargs, ) -> TransformerConfig: # Llama4ForConditionalGeneration raise NotImplementedError("Llama4ForConditionalGeneration is not supported yet") diff --git a/Agent0/executor_train/verl/verl/models/mcore/loader.py b/Agent0/executor_train/verl/verl/models/mcore/loader.py index 659b4ba..9f2dad8 100644 --- a/Agent0/executor_train/verl/verl/models/mcore/loader.py +++ b/Agent0/executor_train/verl/verl/models/mcore/loader.py @@ -42,7 +42,8 @@ def _megatron_calc_layer_map(config): for pp_rank_idx in range(pp_size): for virtual_pp_rank_idx in range(virtual_pp_size): layer_offset 
= ( - virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model + virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + + pp_rank_idx * num_layers_per_model ) for layer_idx in range(num_layers_per_model): layer_map[layer_offset + layer_idx] = ( @@ -53,7 +54,9 @@ def _megatron_calc_layer_map(config): return layer_map -def load_state_dict_to_megatron_gptmodel(state_dict, wrapped_models, config, params_dtype, is_value_model=False): +def load_state_dict_to_megatron_gptmodel( + state_dict, wrapped_models, config, params_dtype, is_value_model=False +): """Load merged state_dict to sharded Megatron module in training.""" from megatron.core import DistributedDataParallel as LocalDDP from megatron.core import mpu @@ -71,13 +74,17 @@ def _get_gpt_model(model): def broadcast_params(module): for param in module.parameters(): torch.distributed.broadcast( - param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group() + param.data, + src=mpu.get_data_parallel_src_rank(), + group=mpu.get_data_parallel_group(), ) dp_rank = mpu.get_data_parallel_rank() pp_rank = mpu.get_pipeline_model_parallel_rank() cp_rank = mpu.get_context_parallel_rank() - src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=0, cp_rank=cp_rank) + src_rank = _megatron_calc_global_rank( + tp_rank=0, dp_rank=0, pp_rank=0, cp_rank=cp_rank + ) pp_size = mpu.get_pipeline_model_parallel_world_size() virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 mp_group = mpu.get_model_parallel_group() @@ -135,7 +142,9 @@ def _broadcast_tensor(tensor, name) -> torch.Tensor: tensor.data.copy_(weight) dist.broadcast(tensor, src=src_rank, group=mp_group) - def _broadcast_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor: + def _broadcast_tp_shard_tensor_vocab( + tensor, name, chunk_dim=0, mutate_func=None + ) -> torch.Tensor: """broadcast tensor in tp shards across 
mp_group""" nonlocal state_dict nonlocal mp_group @@ -171,10 +180,12 @@ def _broadcast_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None requires_grad=False, ) else: - assert tensor.shape == chunk_shape, ( - f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" + assert ( + tensor.shape == chunk_shape + ), f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" + sync_tensor = torch.empty_like( + tensor, device=get_device_id(), requires_grad=False ) - sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) for i in range(tp_size): if torch.distributed.get_rank() == src_rank: @@ -183,7 +194,9 @@ def _broadcast_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None if (i == tp_rank) and (tensor is not None): tensor.data.copy_(sync_tensor) - def _broadcast_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor: + def _broadcast_tp_shard_tensor( + tensor, name, chunk_dim=0, mutate_func=None + ) -> torch.Tensor: """broadcast tensor in tp shards across mp_group""" nonlocal state_dict nonlocal mp_group @@ -218,10 +231,12 @@ def _broadcast_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> t requires_grad=False, ) else: - assert tensor.shape == chunk_shape, ( - f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" + assert ( + tensor.shape == chunk_shape + ), f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" + sync_tensor = torch.empty_like( + tensor, device=get_device_id(), requires_grad=False ) - sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) for i in range(tp_size): if torch.distributed.get_rank() == src_rank: @@ -241,15 +256,22 @@ def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tens gate_weight = state_dict[gate_name] up_weight = state_dict[up_name] 
new_gate_up_weight = torch.empty( - config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=get_device_id() + config.intermediate_size * 2, + config.hidden_size, + dtype=params_dtype, + device=get_device_id(), ) for i in range(tp_size): intermediate_size_tp = config.intermediate_size // tp_size - gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] - up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] - new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)].copy_( - torch.cat([gate_weight_tp, up_weight_tp], dim=0) - ) + gate_weight_tp = gate_weight[ + i * intermediate_size_tp : (i + 1) * intermediate_size_tp + ] + up_weight_tp = up_weight[ + i * intermediate_size_tp : (i + 1) * intermediate_size_tp + ] + new_gate_up_weight[ + intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1) + ].copy_(torch.cat([gate_weight_tp, up_weight_tp], dim=0)) tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0) chunk_shape = tensor_chunk[0].shape @@ -261,7 +283,9 @@ def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tens chunk_shape = obj_list[0] if chunk_shape is None: # all or none ranks in the mp_group should reach here - print_rank_0(f"tp_shard tensor:[{gate_name, up_name}] not in state_dict, skip loading") + print_rank_0( + f"tp_shard tensor:[{gate_name, up_name}] not in state_dict, skip loading" + ) return if tensor is None: @@ -276,7 +300,9 @@ def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tens f"rank #{torch.distributed.get_rank() == src_rank:} tensor {gate_name, up_name} shape " f"{tensor.shape} != {chunk_shape}" ) - sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) + sync_tensor = torch.empty_like( + tensor, device=get_device_id(), requires_grad=False + ) for i in range(tp_size): if torch.distributed.get_rank() == src_rank: @@ -285,7 +311,9 @@ def 
_broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tens if (i == tp_rank) and (tensor is not None): tensor.data.copy_(sync_tensor) - def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) -> torch.Tensor: + def _broadcast_tp_shard_tensor_qkv( + tensor, q_name, k_name, v_name, bias=False + ) -> torch.Tensor: """broadcast tensor in tp shards across mp_group""" nonlocal state_dict nonlocal mp_group @@ -293,34 +321,61 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) - tp_size = mpu.get_tensor_model_parallel_world_size() if torch.distributed.get_rank() == src_rank: - assert q_name in state_dict and k_name in state_dict and v_name in state_dict + assert ( + q_name in state_dict and k_name in state_dict and v_name in state_dict + ) full_weight_q = state_dict[q_name] full_weight_k = state_dict[k_name] full_weight_v = state_dict[v_name] - hidden_size_per_head = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + hidden_size_per_head = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) if config.num_key_value_heads >= tp_size: q_size_tp = hidden_size_per_head * config.num_attention_heads // tp_size - kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size + kv_size_tp = ( + hidden_size_per_head * config.num_key_value_heads // tp_size + ) total_size = q_size_tp + 2 * kv_size_tp sizes = [total_size * tp_size] if not bias: sizes.append(config.hidden_size) - new_weight_qkv = torch.empty(*sizes, dtype=params_dtype, device=get_device_id()) + new_weight_qkv = torch.empty( + *sizes, dtype=params_dtype, device=get_device_id() + ) for i in range(tp_size): q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] k_part = full_weight_k[i * kv_size_tp : (i + 1) * kv_size_tp] v_part = full_weight_v[i * kv_size_tp : (i + 1) * kv_size_tp] - num_query_groups_per_partition = models[0].config.num_query_groups // tp_size - 
new_weight_qkv_this_tp = new_weight_qkv[i * total_size : (i + 1) * total_size] - q_part_per_head = torch.chunk(q_part, num_query_groups_per_partition, dim=0) - k_part_per_head = torch.chunk(k_part, num_query_groups_per_partition, dim=0) - v_part_per_head = torch.chunk(v_part, num_query_groups_per_partition, dim=0) + num_query_groups_per_partition = ( + models[0].config.num_query_groups // tp_size + ) + new_weight_qkv_this_tp = new_weight_qkv[ + i * total_size : (i + 1) * total_size + ] + q_part_per_head = torch.chunk( + q_part, num_query_groups_per_partition, dim=0 + ) + k_part_per_head = torch.chunk( + k_part, num_query_groups_per_partition, dim=0 + ) + v_part_per_head = torch.chunk( + v_part, num_query_groups_per_partition, dim=0 + ) total_size_per_head = total_size // num_query_groups_per_partition for j in range(num_query_groups_per_partition): - new_weight_qkv_this_tp[j * total_size_per_head : (j + 1) * total_size_per_head].copy_( - torch.cat([q_part_per_head[j], k_part_per_head[j], v_part_per_head[j]], dim=0) + new_weight_qkv_this_tp[ + j * total_size_per_head : (j + 1) * total_size_per_head + ].copy_( + torch.cat( + [ + q_part_per_head[j], + k_part_per_head[j], + v_part_per_head[j], + ], + dim=0, + ) ) else: @@ -330,21 +385,44 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) - sizes = [total_size * tp_size] if not bias: sizes.append(config.hidden_size) - new_weight_qkv = torch.empty(*sizes, dtype=params_dtype, device=get_device_id()) + new_weight_qkv = torch.empty( + *sizes, dtype=params_dtype, device=get_device_id() + ) for i in range(tp_size): q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] - start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head - end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head + start_idx = ( + i * config.num_key_value_heads // tp_size * hidden_size_per_head + ) + end_idx = ( + i * config.num_key_value_heads // tp_size + 1 + ) * 
hidden_size_per_head k_part = full_weight_k[start_idx:end_idx] v_part = full_weight_v[start_idx:end_idx] - new_weight_qkv_this_tp = new_weight_qkv[i * total_size : (i + 1) * total_size] - q_part_per_head = torch.chunk(q_part, config.num_attention_heads, dim=0) - k_part_per_head = torch.chunk(k_part, config.num_attention_heads, dim=0) - v_part_per_head = torch.chunk(v_part, config.num_attention_heads, dim=0) + new_weight_qkv_this_tp = new_weight_qkv[ + i * total_size : (i + 1) * total_size + ] + q_part_per_head = torch.chunk( + q_part, config.num_attention_heads, dim=0 + ) + k_part_per_head = torch.chunk( + k_part, config.num_attention_heads, dim=0 + ) + v_part_per_head = torch.chunk( + v_part, config.num_attention_heads, dim=0 + ) total_size_per_head = total_size // config.num_attention_heads for j in range(config.num_attention_heads): - new_weight_qkv_this_tp[j * total_size_per_head : (j + 1) * total_size_per_head].copy_( - torch.cat([q_part_per_head[j], k_part_per_head[j], v_part_per_head[j]], dim=0) + new_weight_qkv_this_tp[ + j * total_size_per_head : (j + 1) * total_size_per_head + ].copy_( + torch.cat( + [ + q_part_per_head[j], + k_part_per_head[j], + v_part_per_head[j], + ], + dim=0, + ) ) tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0) @@ -357,7 +435,9 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) - chunk_shape = obj_list[0] if chunk_shape is None: # all or none ranks in the mp_group should reach here - print_rank_0(f"tp_shard tensor:[{q_name, k_name, v_name}] not in state_dict, skip loading") + print_rank_0( + f"tp_shard tensor:[{q_name, k_name, v_name}] not in state_dict, skip loading" + ) return if tensor is None: @@ -368,10 +448,12 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) - requires_grad=False, ) else: - assert tensor.shape == chunk_shape, ( - f"rank #{torch.distributed.get_rank()} tensor {q_name} shape {tensor.shape} != {chunk_shape}" + assert ( + tensor.shape == 
chunk_shape + ), f"rank #{torch.distributed.get_rank()} tensor {q_name} shape {tensor.shape} != {chunk_shape}" + sync_tensor = torch.empty_like( + tensor, device=get_device_id(), requires_grad=False ) - sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) for i in range(tp_size): if torch.distributed.get_rank() == src_rank: @@ -388,7 +470,9 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) - embed_tokens_weight = None if pp_rank == 0: embed_tokens_weight = gpt_model_module.embedding.word_embeddings.weight - _broadcast_tp_shard_tensor_vocab(embed_tokens_weight, "model.embed_tokens.weight") + _broadcast_tp_shard_tensor_vocab( + embed_tokens_weight, "model.embed_tokens.weight" + ) # Transformer layers # ------------------- @@ -396,36 +480,58 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) - for layer in range(config.num_hidden_layers): layer_name = f"model.layers.{layer}" - print_rank_0(f"loading layer #{layer}, with layer_name model.layers.{layer}...") + print_rank_0( + f"loading layer #{layer}, with layer_name model.layers.{layer}..." 
+ ) dst_pp_rank, dst_virtual_pp_rank, dst_layer_idx = layer_map[layer] gpt_model_module = _get_gpt_model(models[dst_virtual_pp_rank]) sync_layer = gpt_model_module.decoder.layers[dst_layer_idx] _broadcast_tensor( - sync_layer.self_attention.linear_qkv.layer_norm_weight if dst_pp_rank == pp_rank else None, + ( + sync_layer.self_attention.linear_qkv.layer_norm_weight + if dst_pp_rank == pp_rank + else None + ), f"{layer_name}.input_layernorm.weight", ) if f"{layer_name}.self_attn.q_norm.weight" in state_dict: _broadcast_tensor( - sync_layer.self_attention.q_layernorm.weight if dst_pp_rank == pp_rank else None, + ( + sync_layer.self_attention.q_layernorm.weight + if dst_pp_rank == pp_rank + else None + ), f"{layer_name}.self_attn.q_norm.weight", ) _broadcast_tensor( - sync_layer.self_attention.k_layernorm.weight if dst_pp_rank == pp_rank else None, + ( + sync_layer.self_attention.k_layernorm.weight + if dst_pp_rank == pp_rank + else None + ), f"{layer_name}.self_attn.k_norm.weight", ) _broadcast_tp_shard_tensor_qkv( - sync_layer.self_attention.linear_qkv.weight if dst_pp_rank == pp_rank else None, + ( + sync_layer.self_attention.linear_qkv.weight + if dst_pp_rank == pp_rank + else None + ), f"{layer_name}.self_attn.q_proj.weight", f"{layer_name}.self_attn.k_proj.weight", f"{layer_name}.self_attn.v_proj.weight", ) if f"{layer_name}.self_attn.q_proj.bias" in state_dict: _broadcast_tp_shard_tensor_qkv( - sync_layer.self_attention.linear_qkv.bias if dst_pp_rank == pp_rank else None, + ( + sync_layer.self_attention.linear_qkv.bias + if dst_pp_rank == pp_rank + else None + ), f"{layer_name}.self_attn.q_proj.bias", f"{layer_name}.self_attn.k_proj.bias", f"{layer_name}.self_attn.v_proj.bias", @@ -433,12 +539,20 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) - ) _broadcast_tp_shard_tensor( - sync_layer.self_attention.linear_proj.weight if dst_pp_rank == pp_rank else None, + ( + sync_layer.self_attention.linear_proj.weight + if dst_pp_rank == 
pp_rank + else None + ), f"{layer_name}.self_attn.o_proj.weight", chunk_dim=1, ) _broadcast_tensor( - sync_layer.mlp.linear_fc1.layer_norm_weight if dst_pp_rank == pp_rank else None, + ( + sync_layer.mlp.linear_fc1.layer_norm_weight + if dst_pp_rank == pp_rank + else None + ), f"{layer_name}.post_attention_layernorm.weight", ) @@ -469,9 +583,15 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) - if is_value_model: # if torch.distributed.get_rank() == src_rank: - if "lm_head.weight" in state_dict and state_dict["lm_head.weight"].shape[0] == 1: + if ( + "lm_head.weight" in state_dict + and state_dict["lm_head.weight"].shape[0] == 1 + ): _broadcast_tensor(lm_head_weight, "lm_head.weight") - elif "reward_head.weight" in state_dict and state_dict["reward_head.weight"].shape[0] == 1: + elif ( + "reward_head.weight" in state_dict + and state_dict["reward_head.weight"].shape[0] == 1 + ): _broadcast_tensor(lm_head_weight, "reward_head.weight") print_rank_0("load lm_head from value_head weight") else: @@ -489,4 +609,6 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) - broadcast_params(wrapped_model) pass get_torch_device().empty_cache() - print_rank_0(f"loading megatron ckpt done, time elapsed {time.time() - start_time}s") + print_rank_0( + f"loading megatron ckpt done, time elapsed {time.time() - start_time}s" + ) diff --git a/Agent0/executor_train/verl/verl/models/mcore/mbridge.py b/Agent0/executor_train/verl/verl/models/mcore/mbridge.py index 35c32d6..f1d8227 100644 --- a/Agent0/executor_train/verl/verl/models/mcore/mbridge.py +++ b/Agent0/executor_train/verl/verl/models/mcore/mbridge.py @@ -15,9 +15,14 @@ try: from mbridge import AutoBridge - from mbridge.utils.post_creation_callbacks import freeze_moe_router, make_value_model + from mbridge.utils.post_creation_callbacks import ( + freeze_moe_router, + make_value_model, + ) except ImportError: - print("mbridge package not found. 
Please install mbridge with `pip install verl[mcore]` or `pip install mbridge`") + print( + "mbridge package not found. Please install mbridge with `pip install verl[mcore]` or `pip install mbridge`" + ) raise __all__ = ["AutoBridge", "make_value_model", "freeze_moe_router"] diff --git a/Agent0/executor_train/verl/verl/models/mcore/model_forward.py b/Agent0/executor_train/verl/verl/models/mcore/model_forward.py index e70e11f..83f738d 100644 --- a/Agent0/executor_train/verl/verl/models/mcore/model_forward.py +++ b/Agent0/executor_train/verl/verl/models/mcore/model_forward.py @@ -16,7 +16,12 @@ from verl.utils.megatron_utils import unwrap_model -from .util import postprocess_packed_seqs, preprocess_packed_seqs, recover_left_padding, remove_left_padding +from .util import ( + postprocess_packed_seqs, + preprocess_packed_seqs, + recover_left_padding, + remove_left_padding, +) def gptmodel_forward( @@ -36,7 +41,9 @@ def gptmodel_forward( post_process = unwrap_model(model).post_process if pack_seqs: batch_size, seq_len = attention_mask.shape[:2] - input_ids_rmpad, packed_seq_params = preprocess_packed_seqs(input_ids, attention_mask, pre_process=pre_process) + input_ids_rmpad, packed_seq_params = preprocess_packed_seqs( + input_ids, attention_mask, pre_process=pre_process + ) input_ids_rmpad = input_ids_rmpad.contiguous() output_orig = model( input_ids=input_ids_rmpad, @@ -52,23 +59,47 @@ def gptmodel_forward( output_dict = logits_processor(output_orig, **args) output = { k: postprocess_packed_seqs( - v, packed_seq_params, attention_mask, batch_size, seq_len, post_process=post_process + v, + packed_seq_params, + attention_mask, + batch_size, + seq_len, + post_process=post_process, ) for k, v in output_dict.items() } else: output = postprocess_packed_seqs( - output_orig, packed_seq_params, attention_mask, batch_size, seq_len, post_process=post_process + output_orig, + packed_seq_params, + attention_mask, + batch_size, + seq_len, + post_process=post_process, ) else: - 
assert logits_processor is None, "logits_processor is not supported for non-packed sequence" + assert ( + logits_processor is None + ), "logits_processor is not supported for non-packed sequence" batch_size, sequence_length = attention_mask.shape new_input_ids, new_attention_mask, new_position_ids = remove_left_padding( - input_ids, attention_mask, position_ids, sequence_parallel, pre_process=pre_process + input_ids, + attention_mask, + position_ids, + sequence_parallel, + pre_process=pre_process, + ) + output = model( + input_ids=new_input_ids, + attention_mask=new_attention_mask, + position_ids=new_position_ids, ) - output = model(input_ids=new_input_ids, attention_mask=new_attention_mask, position_ids=new_position_ids) output = recover_left_padding( - output, new_attention_mask, attention_mask, sequence_length, post_process=post_process + output, + new_attention_mask, + attention_mask, + sequence_length, + post_process=post_process, ) if value_model and post_process: output = output[..., 0] @@ -90,18 +121,26 @@ def gptmodel_forward_qwen2_5_vl( ): from megatron.core import parallel_state as mpu - assert mpu.get_context_parallel_world_size() == 1, "qwen2_5_vl's context parallel is not accurate yet" + assert ( + mpu.get_context_parallel_world_size() == 1 + ), "qwen2_5_vl's context parallel is not accurate yet" pre_process = unwrap_model(model).pre_process post_process = unwrap_model(model).post_process pixel_values = ( - multi_modal_inputs["pixel_values"].to(input_ids.device) if "pixel_values" in multi_modal_inputs else None + multi_modal_inputs["pixel_values"].to(input_ids.device) + if "pixel_values" in multi_modal_inputs + else None ) image_grid_thw = ( - multi_modal_inputs["image_grid_thw"].to(input_ids.device) if "image_grid_thw" in multi_modal_inputs else None + multi_modal_inputs["image_grid_thw"].to(input_ids.device) + if "image_grid_thw" in multi_modal_inputs + else None ) if pack_seqs: batch_size, seq_len = attention_mask.shape[:2] - input_ids_rmpad, 
packed_seq_params = preprocess_packed_seqs(input_ids, attention_mask, pre_process=True) + input_ids_rmpad, packed_seq_params = preprocess_packed_seqs( + input_ids, attention_mask, pre_process=True + ) input_ids_rmpad = input_ids_rmpad.contiguous() output_orig = model( input_ids=input_ids_rmpad, @@ -120,18 +159,32 @@ def gptmodel_forward_qwen2_5_vl( output_dict = logits_processor(output_orig, **args) output = { k: postprocess_packed_seqs( - v, packed_seq_params, attention_mask, batch_size, seq_len, post_process=post_process + v, + packed_seq_params, + attention_mask, + batch_size, + seq_len, + post_process=post_process, ) for k, v in output_dict.items() } else: output = postprocess_packed_seqs( - output_orig, packed_seq_params, attention_mask, batch_size, seq_len, post_process=post_process + output_orig, + packed_seq_params, + attention_mask, + batch_size, + seq_len, + post_process=post_process, ) else: batch_size, sequence_length = attention_mask.shape new_input_ids, new_attention_mask, new_position_ids = remove_left_padding( - input_ids, attention_mask, position_ids, sequence_parallel, pre_process=pre_process + input_ids, + attention_mask, + position_ids, + sequence_parallel, + pre_process=pre_process, ) output = model( input_ids=new_input_ids, @@ -141,7 +194,11 @@ def gptmodel_forward_qwen2_5_vl( image_grid_thw=image_grid_thw, ) output = recover_left_padding( - output, new_attention_mask, attention_mask, sequence_length, post_process=post_process + output, + new_attention_mask, + attention_mask, + sequence_length, + post_process=post_process, ) if value_model and post_process: output = output[..., 0] diff --git a/Agent0/executor_train/verl/verl/models/mcore/model_forward_fused.py b/Agent0/executor_train/verl/verl/models/mcore/model_forward_fused.py index fc55ef1..401ad13 100644 --- a/Agent0/executor_train/verl/verl/models/mcore/model_forward_fused.py +++ b/Agent0/executor_train/verl/verl/models/mcore/model_forward_fused.py @@ -76,10 +76,14 @@ def 
fused_forward_gptmodel( post_process: bool = unwrap_model(model).post_process batch_size, seq_len = attention_mask.shape[:2] - input_ids_rmpad, packed_seq_params = preprocess_packed_seqs(input_ids, attention_mask, pre_process=pre_process) + input_ids_rmpad, packed_seq_params = preprocess_packed_seqs( + input_ids, attention_mask, pre_process=pre_process + ) input_ids_rmpad = input_ids_rmpad.contiguous() labels_rmpad, _ = preprocess_packed_seqs(labels, attention_mask, pre_process=True) - labels_mask_rmpad, _ = preprocess_packed_seqs(labels_mask, attention_mask, pre_process=True) + labels_mask_rmpad, _ = preprocess_packed_seqs( + labels_mask, attention_mask, pre_process=True + ) labels_rmpad = labels_rmpad.contiguous() labels_mask_rmpad = labels_mask_rmpad.contiguous() @@ -121,16 +125,24 @@ def fused_forward_qwen2_5_vl( post_process = unwrap_model(model).post_process pixel_values = ( - multi_modal_inputs["pixel_values"].to(input_ids.device) if "pixel_values" in multi_modal_inputs else None + multi_modal_inputs["pixel_values"].to(input_ids.device) + if "pixel_values" in multi_modal_inputs + else None ) image_grid_thw = ( - multi_modal_inputs["image_grid_thw"].to(input_ids.device) if "image_grid_thw" in multi_modal_inputs else None + multi_modal_inputs["image_grid_thw"].to(input_ids.device) + if "image_grid_thw" in multi_modal_inputs + else None ) batch_size, seq_len = attention_mask.shape[:2] - input_ids_rmpad, packed_seq_params = preprocess_packed_seqs(input_ids, attention_mask, pre_process=True) + input_ids_rmpad, packed_seq_params = preprocess_packed_seqs( + input_ids, attention_mask, pre_process=True + ) labels_rmpad, _ = preprocess_packed_seqs(labels, attention_mask, pre_process=True) - labels_mask_rmpad, _ = preprocess_packed_seqs(labels_mask, attention_mask, pre_process=True) + labels_mask_rmpad, _ = preprocess_packed_seqs( + labels_mask, attention_mask, pre_process=True + ) labels_rmpad = labels_rmpad.contiguous() labels_mask_rmpad = 
labels_mask_rmpad.contiguous() input_ids_rmpad = input_ids_rmpad.contiguous() @@ -198,9 +210,14 @@ def _fused_GPTModel_forward( rotary_pos_emb = None rotary_pos_cos = None rotary_pos_sin = None - if self.position_embedding_type == "rope" and not self.config.multi_latent_attention: + if ( + self.position_embedding_type == "rope" + and not self.config.multi_latent_attention + ): if not self.training and self.config.flash_decode and inference_context: - assert inference_context.is_static_batching(), "GPTModel currently only supports static inference batching." + assert ( + inference_context.is_static_batching() + ), "GPTModel currently only supports static inference batching." # Flash decoding uses precomputed cos and sin for RoPE rotary_pos_cos, rotary_pos_sin = self.rotary_pos_emb_cache.setdefault( inference_context.max_sequence_length, @@ -208,13 +225,21 @@ def _fused_GPTModel_forward( ) else: rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len( - inference_context, self.decoder, decoder_input, self.config, packed_seq_params + inference_context, + self.decoder, + decoder_input, + self.config, + packed_seq_params, ) rotary_pos_emb = self.rotary_pos_emb( rotary_seq_len, - packed_seq=packed_seq_params is not None and packed_seq_params.qkv_format == "thd", + packed_seq=packed_seq_params is not None + and packed_seq_params.qkv_format == "thd", ) - elif self.position_embedding_type == "mrope" and not self.config.multi_latent_attention: + elif ( + self.position_embedding_type == "mrope" + and not self.config.multi_latent_attention + ): if self.training or not self.config.flash_decode: rotary_pos_emb = self.rotary_pos_emb(position_ids, self.mrope_section) else: @@ -231,7 +256,8 @@ def _fused_GPTModel_forward( and not self.training ): sequence_len_offset = torch.tensor( - [inference_context.sequence_len_offset] * inference_context.current_batch_size, + [inference_context.sequence_len_offset] + * inference_context.current_batch_size, dtype=torch.int32, 
device=rotary_pos_cos.device, # Co-locate this with the rotary tensors ) @@ -257,7 +283,9 @@ def _fused_GPTModel_forward( # Process inference output. if inference_context and not inference_context.is_static_batching(): - hidden_states = inference_context.last_token_logits(hidden_states.squeeze(1).unsqueeze(0)).unsqueeze(1) + hidden_states = inference_context.last_token_logits( + hidden_states.squeeze(1).unsqueeze(0) + ).unsqueeze(1) # logits and loss output_weight = None diff --git a/Agent0/executor_train/verl/verl/models/mcore/model_initializer.py b/Agent0/executor_train/verl/verl/models/mcore/model_initializer.py index 4c01b12..52f7379 100644 --- a/Agent0/executor_train/verl/verl/models/mcore/model_initializer.py +++ b/Agent0/executor_train/verl/verl/models/mcore/model_initializer.py @@ -17,7 +17,10 @@ # use mcore transformer config to initialize the model from abc import ABC, abstractmethod -from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec, get_gpt_mtp_block_spec +from megatron.core.models.gpt.gpt_layer_specs import ( + get_gpt_decoder_block_spec, + get_gpt_mtp_block_spec, +) from megatron.core.models.gpt.gpt_model import GPTModel from .config_converter import PretrainedConfig, TransformerConfig @@ -33,7 +36,8 @@ def __init__(self, tfconfig: TransformerConfig, hf_config: PretrainedConfig): @abstractmethod def get_transformer_layer_spec(self): """Get the transformer layer specification. 
- https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/models/gpt/gpt_layer_specs.py""" + https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/models/gpt/gpt_layer_specs.py + """ pass def get_rope_scaling_args(self) -> dict: @@ -42,7 +46,9 @@ def get_rope_scaling_args(self) -> dict: if "rope_scaling" in self.hf_config: if self.hf_config.rope_scaling is not None: # assert self.hf_config.rope_scaling["type"] == "linear", "only linear scaling is supported for now" - rope_scaling_args["seq_len_interpolation_factor"] = self.hf_config.rope_scaling["factor"] + rope_scaling_args["seq_len_interpolation_factor"] = ( + self.hf_config.rope_scaling["factor"] + ) return rope_scaling_args def initialize( @@ -83,10 +89,14 @@ def initialize( ) if post_process and value: - from verl.models.llama.megatron.layers.parallel_linear import LinearForLastLayer + from verl.models.llama.megatron.layers.parallel_linear import ( + LinearForLastLayer, + ) model.output_layer = LinearForLastLayer( - input_size=self.tfconfig.hidden_size, output_size=1, config=self.tfconfig + input_size=self.tfconfig.hidden_size, + output_size=1, + config=self.tfconfig, ) return model @@ -96,7 +106,9 @@ class DenseModel(BaseModelInitializer): """Initializer for dense models like Llama and Qwen2.""" def get_transformer_layer_spec(self): - assert self.tfconfig.normalization == "RMSNorm", "only RMSNorm is supported for now" + assert ( + self.tfconfig.normalization == "RMSNorm" + ), "only RMSNorm is supported for now" return get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True) @@ -104,12 +116,18 @@ class Qwen2MoEModel(BaseModelInitializer): """Initializer for Qwen2 MoE models.""" def get_transformer_layer_spec(self): - assert self.tfconfig.normalization == "RMSNorm", "only RMSNorm is supported for now" - transformer_layer_spec = get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True) + assert ( + self.tfconfig.normalization == "RMSNorm" + ), "only RMSNorm is supported 
for now" + transformer_layer_spec = get_gpt_decoder_block_spec( + self.tfconfig, use_transformer_engine=True + ) # Patch layer spec for shared experts for i in range(len(transformer_layer_spec.layer_specs)): - transformer_layer_spec.layer_specs[i].submodules.mlp.submodules.shared_experts.params["gate"] = True + transformer_layer_spec.layer_specs[ + i + ].submodules.mlp.submodules.shared_experts.params["gate"] = True return transformer_layer_spec @@ -127,8 +145,12 @@ class MixtralModel(BaseModelInitializer): """Initializer for Mixtral models.""" def get_transformer_layer_spec(self): - assert self.tfconfig.normalization == "RMSNorm", "only RMSNorm is supported for now" - transformer_layer_spec = get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True) + assert ( + self.tfconfig.normalization == "RMSNorm" + ), "only RMSNorm is supported for now" + transformer_layer_spec = get_gpt_decoder_block_spec( + self.tfconfig, use_transformer_engine=True + ) return transformer_layer_spec def initialize(self, **kwargs): @@ -144,8 +166,12 @@ class Qwen3MoEModel(BaseModelInitializer): """Initializer for Qwen3 MoE models.""" def get_transformer_layer_spec(self): - assert self.tfconfig.normalization == "RMSNorm", "only RMSNorm is supported for now" - transformer_layer_spec = get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True) + assert ( + self.tfconfig.normalization == "RMSNorm" + ), "only RMSNorm is supported for now" + transformer_layer_spec = get_gpt_decoder_block_spec( + self.tfconfig, use_transformer_engine=True + ) return transformer_layer_spec def initialize(self, **kwargs): @@ -162,7 +188,9 @@ class DeepseekV3Model(BaseModelInitializer): """Initializer for DeepseekV3 models.""" def get_transformer_layer_spec(self): - transformer_layer_spec = get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True) + transformer_layer_spec = get_gpt_decoder_block_spec( + self.tfconfig, use_transformer_engine=True + ) return transformer_layer_spec 
def get_rope_scaling_args(self) -> dict: @@ -180,7 +208,9 @@ def initialize( # MTP if self.tfconfig.mtp_num_layers is not None: transformer_layer_spec = self.get_transformer_layer_spec() - mtp_block_spec = get_gpt_mtp_block_spec(self.tfconfig, transformer_layer_spec, use_transformer_engine=True) + mtp_block_spec = get_gpt_mtp_block_spec( + self.tfconfig, transformer_layer_spec, use_transformer_engine=True + ) kwargs["mtp_block_spec"] = mtp_block_spec model = super().initialize(**kwargs) @@ -195,7 +225,9 @@ class Qwen25VLModel(BaseModelInitializer): """Initializer for Qwen2.5 VL models.""" def get_transformer_layer_spec(self): - transformer_layer_spec = get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True) + transformer_layer_spec = get_gpt_decoder_block_spec( + self.tfconfig, use_transformer_engine=True + ) return transformer_layer_spec def initialize( @@ -213,11 +245,20 @@ def initialize( transformer_layer_spec = self.get_transformer_layer_spec() - from megatron.core.extensions.transformer_engine import TEColumnParallelLinear, TERowParallelLinear + from megatron.core.extensions.transformer_engine import ( + TEColumnParallelLinear, + TERowParallelLinear, + ) from megatron.core.models.gpt.moe_module_specs import MLPSubmodules - from megatron.core.models.vision.vit_layer_specs import get_vit_layer_with_transformer_engine_spec + from megatron.core.models.vision.vit_layer_specs import ( + get_vit_layer_with_transformer_engine_spec, + ) - from .qwen2_5_vl import Qwen2_5VLModel, get_vision_model_config, get_vision_projection_config + from .qwen2_5_vl import ( + Qwen2_5VLModel, + get_vision_model_config, + get_vision_projection_config, + ) vision_transformer_config = get_vision_model_config(deepcopy(tfconfig)) vision_transformer_config.pipeline_model_parallel_size = 1 @@ -254,7 +295,9 @@ def initialize( ) if post_process and value: - from verl.models.llama.megatron.layers.parallel_linear import LinearForLastLayer + from 
verl.models.llama.megatron.layers.parallel_linear import ( + LinearForLastLayer, + ) qwen25_vl_model.language_model.output_layer = LinearForLastLayer( input_size=tfconfig.hidden_size, output_size=1, config=tfconfig diff --git a/Agent0/executor_train/verl/verl/models/mcore/patch_v012.py b/Agent0/executor_train/verl/verl/models/mcore/patch_v012.py index d54a3eb..bbe54ce 100644 --- a/Agent0/executor_train/verl/verl/models/mcore/patch_v012.py +++ b/Agent0/executor_train/verl/verl/models/mcore/patch_v012.py @@ -44,9 +44,13 @@ def patch_get_query_key_value_tensors( """ # s = sequence length, b = batch size, h = hidden size, n = num attention heads # Attention heads [s, b, n*h] - assert hidden_states.ndim == 3, f"hidden_states should be 3D, [s, b, n*h], got {hidden_states.ndim}D" + assert ( + hidden_states.ndim == 3 + ), f"hidden_states should be 3D, [s, b, n*h], got {hidden_states.ndim}D" - inference_context = deprecate_inference_params(inference_context, inference_params) + inference_context = deprecate_inference_params( + inference_context, inference_params + ) # ========================================= # Prepare RoPE and seqlen related params @@ -58,7 +62,9 @@ def patch_get_query_key_value_tensors( # rotary_pos_emb:[s, b, 1, 64] mscale = 1.0 if self.config.rope_type == "rope": - packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == "thd" + packed_seq = ( + packed_seq_params is not None and packed_seq_params.qkv_format == "thd" + ) rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len, packed_seq=packed_seq) else: rotary_pos_emb, mscale = self.rotary_pos_emb(rotary_seq_len) @@ -92,12 +98,17 @@ def patch_get_query_key_value_tensors( # elif linear_kv_down_proj is Linear: # kv_combined: [s / TP, b, (kv_lora_rank + qk_pos_emb_head_dim)] kv_combined, _ = self.linear_kv_down_proj(hidden_states) - if kv_combined.size(-1) != self.config.kv_lora_rank + self.config.qk_pos_emb_head_dim: + if ( + kv_combined.size(-1) + != self.config.kv_lora_rank + 
self.config.qk_pos_emb_head_dim + ): # kv_combined: [s, b, (kv_lora_rank + qk_pos_emb_head_dim)] kv_combined = gather_from_tensor_model_parallel_region(kv_combined) # kv_compressed:[s, b, kv_lora_rank], k_pos_emb: [s, b, qk_pos_emb_head_dim] kv_compressed, k_pos_emb = torch.split( - kv_combined, [self.config.kv_lora_rank, self.config.qk_pos_emb_head_dim], dim=-1 + kv_combined, + [self.config.kv_lora_rank, self.config.qk_pos_emb_head_dim], + dim=-1, ) if self.config.sequence_parallel: # kv_compressed:[s / TP, b, kv_lora_rank] @@ -105,7 +116,9 @@ def patch_get_query_key_value_tensors( else: # kv_compressed:[s / TP, b, kv_lora_rank], k_pos_emb: [s / TP, b, qk_pos_emb_head_dim] kv_compressed, k_pos_emb = torch.split( - kv_combined, [self.config.kv_lora_rank, self.config.qk_pos_emb_head_dim], dim=-1 + kv_combined, + [self.config.kv_lora_rank, self.config.qk_pos_emb_head_dim], + dim=-1, ) if parallel_state.get_tensor_model_parallel_world_size() > 1: # k_pos_emb: [s, b, qk_pos_emb_head_dim] @@ -116,7 +129,9 @@ def patch_get_query_key_value_tensors( # ========================================= # QKV up projection and RoPE apply # ========================================= - def qkv_up_proj_and_rope_apply(q_compressed, kv_compressed, k_pos_emb, rotary_pos_emb): + def qkv_up_proj_and_rope_apply( + q_compressed, kv_compressed, k_pos_emb, rotary_pos_emb + ): if self.config.q_lora_rank is not None: q, _ = self.linear_q_up_proj(q_compressed) else: @@ -126,7 +141,9 @@ def qkv_up_proj_and_rope_apply(q_compressed, kv_compressed, k_pos_emb, rotary_po q_len, bsz, _ = q.size() # q: [s, b, n, 192] - q = q.view(q_len, bsz, self.num_attention_heads_per_partition, self.q_head_dim) + q = q.view( + q_len, bsz, self.num_attention_heads_per_partition, self.q_head_dim + ) # kv: [s, b, 2048] kv, _ = self.linear_kv_up_proj(kv_compressed) @@ -155,10 +172,14 @@ def qkv_up_proj_and_rope_apply(q_compressed, kv_compressed, k_pos_emb, rotary_po k_pos_emb = torch.unsqueeze(k_pos_emb, 2) # q: [s, b, n, 
128], q_pos_emb: [s, b, n, 64] - q_no_pe, q_pos_emb = torch.split(q, [self.config.qk_head_dim, self.config.qk_pos_emb_head_dim], dim=-1) + q_no_pe, q_pos_emb = torch.split( + q, [self.config.qk_head_dim, self.config.qk_pos_emb_head_dim], dim=-1 + ) # k_no_pe: [s, b, n, 128], value: [s, b, n, 128] - k_no_pe, value = torch.split(kv, [self.config.qk_head_dim, self.config.v_head_dim], dim=-1) + k_no_pe, value = torch.split( + kv, [self.config.qk_head_dim, self.config.v_head_dim], dim=-1 + ) if packed_seq_params is not None: cu_seqlens_q = packed_seq_params.cu_seqlens_q @@ -190,11 +211,15 @@ def qkv_up_proj_and_rope_apply(q_compressed, kv_compressed, k_pos_emb, rotary_po # query: [s, b, n, 192] query = torch.cat([q_no_pe, q_pos_emb], dim=-1) if packed_seq_params is not None: - k_pos_emb = k_pos_emb.expand(-1, self.num_attention_heads_per_partition, -1) + k_pos_emb = k_pos_emb.expand( + -1, self.num_attention_heads_per_partition, -1 + ) key = torch.cat([k_no_pe, k_pos_emb], dim=-1) else: # key: [s, b, n, 192] - k_pos_emb = k_pos_emb.expand(-1, -1, self.num_attention_heads_per_partition, -1) + k_pos_emb = k_pos_emb.expand( + -1, -1, self.num_attention_heads_per_partition, -1 + ) key = torch.cat([k_no_pe, k_pos_emb], dim=-1) query = query.contiguous() @@ -205,10 +230,16 @@ def qkv_up_proj_and_rope_apply(q_compressed, kv_compressed, k_pos_emb, rotary_po if self.recompute_up_proj: self.qkv_up_checkpoint = tensor_parallel.CheckpointWithoutOutput() query, key, value = self.qkv_up_checkpoint.checkpoint( - qkv_up_proj_and_rope_apply, q_compressed, kv_compressed, k_pos_emb, rotary_pos_emb + qkv_up_proj_and_rope_apply, + q_compressed, + kv_compressed, + k_pos_emb, + rotary_pos_emb, ) else: - query, key, value = qkv_up_proj_and_rope_apply(q_compressed, kv_compressed, k_pos_emb, rotary_pos_emb) + query, key, value = qkv_up_proj_and_rope_apply( + q_compressed, kv_compressed, k_pos_emb, rotary_pos_emb + ) return query, key, value diff --git 
a/Agent0/executor_train/verl/verl/models/mcore/qwen2_5_vl/attention.py b/Agent0/executor_train/verl/verl/models/mcore/qwen2_5_vl/attention.py index 91a27cc..7bbfaf6 100644 --- a/Agent0/executor_train/verl/verl/models/mcore/qwen2_5_vl/attention.py +++ b/Agent0/executor_train/verl/verl/models/mcore/qwen2_5_vl/attention.py @@ -63,15 +63,21 @@ def forward( """ - inference_context = deprecate_inference_params(inference_context, inference_params) + inference_context = deprecate_inference_params( + inference_context, inference_params + ) if inference_context and inference_context.is_dynamic_batching(): - assert flash_decode_and_prefill_kernel is not None, ( - "Internal use only: install package `nvidia_chunked_flash_attn`." - ) + assert ( + flash_decode_and_prefill_kernel is not None + ), "Internal use only: install package `nvidia_chunked_flash_attn`." # hidden_states: [sq, b, h] - if self.config.flash_decode and not self.training and inference_context is not None: + if ( + self.config.flash_decode + and not self.training + and inference_context is not None + ): rotary_pos_emb = None else: assert rotary_pos_cos is None and rotary_pos_sin is None @@ -85,7 +91,9 @@ def forward( # ===================== # Get the query, key and value tensors based on the type of attention - # self or cross attn. 
- query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states) + query, key, value = self.get_query_key_value_tensors( + hidden_states, key_value_states + ) # =================================================== # Adjust key, value, and rotary_pos_emb for inference @@ -102,7 +110,9 @@ def forward( ): assert self.layer_number in inference_context.key_value_memory_dict assert inference_context.sequence_len_offset is not None - inference_key_memory, inference_value_memory = inference_context.key_value_memory_dict[self.layer_number] + inference_key_memory, inference_value_memory = ( + inference_context.key_value_memory_dict[self.layer_number] + ) output = self.flash_decode( sequence_len_offset=sequence_len_offset, query_layer=query, @@ -118,15 +128,17 @@ def forward( output, bias = self.linear_proj(context_layer) return output, bias - query, key, value, rotary_pos_emb, attn_mask_type = self._adjust_key_value_for_inference( - inference_context, - query, - key, - value, - rotary_pos_emb, - rotary_pos_cos, - rotary_pos_sin, - sequence_len_offset, + query, key, value, rotary_pos_emb, attn_mask_type = ( + self._adjust_key_value_for_inference( + inference_context, + query, + key, + value, + rotary_pos_emb, + rotary_pos_cos, + rotary_pos_sin, + sequence_len_offset, + ) ) if packed_seq_params is not None: @@ -155,11 +167,17 @@ def forward( if q_pos_emb is not None: # TODO VIJAY: simplify if inference_context is None or inference_context.is_static_batching(): - query = apply_rotary_pos_emb_absolute(query, q_pos_emb, config=self.config, cu_seqlens=cu_seqlens_q) + query = apply_rotary_pos_emb_absolute( + query, q_pos_emb, config=self.config, cu_seqlens=cu_seqlens_q + ) else: - query = inference_context.apply_rotary_emb_query(query, q_pos_emb, self.config, cu_seqlens_q) + query = inference_context.apply_rotary_emb_query( + query, q_pos_emb, self.config, cu_seqlens_q + ) if k_pos_emb is not None: - key = apply_rotary_pos_emb_absolute(key, k_pos_emb, 
config=self.config, cu_seqlens=cu_seqlens_kv) + key = apply_rotary_pos_emb_absolute( + key, k_pos_emb, config=self.config, cu_seqlens=cu_seqlens_kv + ) # TODO, can apply positional embedding to value_layer so it has # absolute positional embedding. diff --git a/Agent0/executor_train/verl/verl/models/mcore/qwen2_5_vl/model.py b/Agent0/executor_train/verl/verl/models/mcore/qwen2_5_vl/model.py index 74e4406..45b4508 100644 --- a/Agent0/executor_train/verl/verl/models/mcore/qwen2_5_vl/model.py +++ b/Agent0/executor_train/verl/verl/models/mcore/qwen2_5_vl/model.py @@ -97,11 +97,15 @@ def __init__( super().__init__(config=language_transformer_config) # patch self_attention to use qwen2_5_vl attention - vision_transformer_layer_spec.submodules.self_attention.module = Qwen2_5VLSelfAttention + vision_transformer_layer_spec.submodules.self_attention.module = ( + Qwen2_5VLSelfAttention + ) for layer_spec in language_transformer_layer_spec.layer_specs: layer_spec.submodules.self_attention.module = Qwen2_5VLSelfAttention - logging.getLogger(__name__).warning("Qwen2VL model is under development and may be missing features.") + logging.getLogger(__name__).warning( + "Qwen2VL model is under development and may be missing features." + ) self.pre_process = pre_process self.post_process = post_process @@ -115,7 +119,10 @@ def __init__( self.image_token_id = image_token_id self.video_token_id = video_token_id - self.square_merge_size = vision_projection_config.ffn_hidden_size // vision_transformer_config.hidden_size + self.square_merge_size = ( + vision_projection_config.ffn_hidden_size + // vision_transformer_config.hidden_size + ) # This attribute is needed to check if an all-reduce is required # on the word embeddings inside `finalize_model_grads._allreduce_word_embedding_grads`. 
@@ -147,7 +154,9 @@ def __init__( scatter_embedding_sequence_parallel=False, ) - self.share_embeddings_and_output_weights = self.language_model.share_embeddings_and_output_weights + self.share_embeddings_and_output_weights = ( + self.language_model.share_embeddings_and_output_weights + ) def shared_embedding_or_output_weight(self): """This is a convenience method to surface the language model's word embeddings, which is @@ -161,14 +170,21 @@ def set_input_tensor(self, input_tensor) -> None: # gives us non-lists or None if not isinstance(input_tensor, list): input_tensor = [input_tensor] - assert len(input_tensor) == 1, "input_tensor should only be length 1 for Qwen2VL" + assert ( + len(input_tensor) == 1 + ), "input_tensor should only be length 1 for Qwen2VL" if self.pre_process: self.encoder_hidden_state = input_tensor[0] else: self.language_model.set_input_tensor(input_tensor[0]) - def freeze(self, freeze_language_model: bool, freeze_vision_model: bool, freeze_vision_projection: bool): + def freeze( + self, + freeze_language_model: bool, + freeze_vision_model: bool, + freeze_vision_projection: bool, + ): """Freeze model modules. Make specific modules non-trainable by setting requires_grad to False for the module's parameters. 
@@ -238,10 +254,12 @@ def forward( vision_data = torch.cat([vision_data, pixel_values_videos], dim=0) video_start_index = image_mask.sum().item() + video_mask.sum().item() use_inference_kv_cache = ( - inference_params is not None and "image_tokens_count" in inference_params.key_value_memory_dict + inference_params is not None + and "image_tokens_count" in inference_params.key_value_memory_dict ) use_inference_kv_cache = ( - inference_params is not None and "image_tokens_count" in inference_params.key_value_memory_dict + inference_params is not None + and "image_tokens_count" in inference_params.key_value_memory_dict ) if use_inference_kv_cache: raise NotImplementedError() @@ -293,22 +311,28 @@ def forward( ) # [text_seq_len, b, h_language] if image_embeds is not None or video_embeds is not None: - combined_embeddings = combined_embeddings.transpose(0, 1).contiguous() + combined_embeddings = combined_embeddings.transpose( + 0, 1 + ).contiguous() if image_embeds is not None: image_mask = (input_ids == self.image_token_id).contiguous() if image_mask.sum() > 0: combined_embeddings = combined_embeddings.clone() combined_embeddings[image_mask] = image_embeds.to( - dtype=combined_embeddings.dtype, device=combined_embeddings.device + dtype=combined_embeddings.dtype, + device=combined_embeddings.device, ) if video_embeds is not None: video_mask = (input_ids == self.video_token_id).contiguous() if video_mask.sum() > 0: combined_embeddings = combined_embeddings.clone() combined_embeddings[video_mask] = video_embeds.to( - dtype=combined_embeddings.dtype, device=combined_embeddings.device + dtype=combined_embeddings.dtype, + device=combined_embeddings.device, ) - combined_embeddings = combined_embeddings.transpose(0, 1).contiguous() + combined_embeddings = combined_embeddings.transpose( + 0, 1 + ).contiguous() else: combined_embeddings = self.language_model.embedding( @@ -316,14 +340,21 @@ def forward( position_ids=None, # NOTE: disable ) # [text_seq_len, b, h_language] if 
self.config.sequence_parallel: - combined_embeddings = tensor_parallel.scatter_to_sequence_parallel_region(combined_embeddings) + combined_embeddings = ( + tensor_parallel.scatter_to_sequence_parallel_region( + combined_embeddings + ) + ) combined_embeddings = combined_embeddings.contiguous() else: combined_embeddings = None from .rope_utils import get_rope_index position_ids, _ = get_rope_index( - input_ids, image_grid_thw=image_grid_thw, video_grid_thw=video_grid_thw, attention_mask=attention_mask + input_ids, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + attention_mask=attention_mask, ) output = self.language_model( diff --git a/Agent0/executor_train/verl/verl/models/mcore/qwen2_5_vl/rope_utils.py b/Agent0/executor_train/verl/verl/models/mcore/qwen2_5_vl/rope_utils.py index fadc74d..1c5cebd 100644 --- a/Agent0/executor_train/verl/verl/models/mcore/qwen2_5_vl/rope_utils.py +++ b/Agent0/executor_train/verl/verl/models/mcore/qwen2_5_vl/rope_utils.py @@ -107,7 +107,9 @@ def get_rope_index( video_token_id = 151656 vision_start_token_id = 151652 mrope_position_deltas = [] - if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): + if input_ids is not None and ( + image_grid_thw is not None or video_grid_thw is not None + ): total_input_ids = input_ids if attention_mask is None: attention_mask = torch.ones_like(total_input_ids) @@ -123,7 +125,9 @@ def get_rope_index( for i, input_ids in enumerate(total_input_ids): input_ids = input_ids[attention_mask[i] == 1] image_nums, video_nums = 0, 0 - vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) + vision_start_indices = torch.argwhere( + input_ids == vision_start_token_id + ).squeeze(1) vision_tokens = input_ids[vision_start_indices + 1] image_nums = (vision_tokens == image_token_id).sum() video_nums = (vision_tokens == video_token_id).sum() @@ -171,8 +175,12 @@ def get_rope_index( ) text_len = ed - st - st_idx = 
llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 - llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + st_idx = ( + llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + ) + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) range_tensor = torch.arange(llm_grid_t).view(-1, 1) expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w) @@ -182,27 +190,53 @@ def get_rope_index( time_tensor_long = time_tensor.long() t_index = time_tensor_long.flatten() - h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() - w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() - llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + h_index = ( + torch.arange(llm_grid_h) + .view(1, -1, 1) + .expand(llm_grid_t, -1, llm_grid_w) + .flatten() + ) + w_index = ( + torch.arange(llm_grid_w) + .view(1, 1, -1) + .expand(llm_grid_t, llm_grid_h, -1) + .flatten() + ) + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + text_len + st_idx + ) st = ed + llm_grid_t * llm_grid_h * llm_grid_w if st < len(input_tokens): - st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + st_idx = ( + llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + ) text_len = len(input_tokens) - st - llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) - position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) - mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) - mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) + position_ids[..., i, attention_mask[i] == 1] = 
llm_positions.to( + position_ids.device + ) + mrope_position_deltas.append( + llm_positions.max() + 1 - len(total_input_ids[i]) + ) + mrope_position_deltas = torch.tensor( + mrope_position_deltas, device=input_ids.device + ).unsqueeze(1) return position_ids, mrope_position_deltas else: if attention_mask is not None: position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) - position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) - max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] + position_ids = ( + position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) + ) + max_position_ids = position_ids.max(0, keepdim=False)[0].max( + -1, keepdim=True + )[0] mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] else: position_ids = ( @@ -233,7 +267,9 @@ def apply_rotary_pos_emb_thd_absolute( Returns: Tensor: Shape [t, h, d]. The input tensor after applying RoPE. 
""" - return _apply_rotary_pos_emb_bshd(t[:, None], freqs, rotary_interleaved=rotary_interleaved).squeeze(1) + return _apply_rotary_pos_emb_bshd( + t[:, None], freqs, rotary_interleaved=rotary_interleaved + ).squeeze(1) def apply_rotary_pos_emb_absolute( @@ -253,7 +289,9 @@ def apply_rotary_pos_emb_absolute( if cu_seqlens is None: # NOTE: TE backends do not support mRoPE in bshd format when bs > 1 if freqs.shape[1] > 1: - return _apply_rotary_pos_emb_bshd(t, freqs, rotary_interleaved=config.rotary_interleaved) + return _apply_rotary_pos_emb_bshd( + t, freqs, rotary_interleaved=config.rotary_interleaved + ) else: return fused_apply_rotary_pos_emb(t, freqs) else: @@ -261,6 +299,10 @@ def apply_rotary_pos_emb_absolute( return fused_apply_rotary_pos_emb(t[:, None], freqs).squeeze(1) else: if cu_seqlens is None: - return _apply_rotary_pos_emb_bshd(t, freqs, rotary_interleaved=config.rotary_interleaved) + return _apply_rotary_pos_emb_bshd( + t, freqs, rotary_interleaved=config.rotary_interleaved + ) else: - return apply_rotary_pos_emb_thd_absolute(t, cu_seqlens, freqs, rotary_interleaved=config.rotary_interleaved) + return apply_rotary_pos_emb_thd_absolute( + t, cu_seqlens, freqs, rotary_interleaved=config.rotary_interleaved + ) diff --git a/Agent0/executor_train/verl/verl/models/mcore/qwen2_5_vl/vision_config.py b/Agent0/executor_train/verl/verl/models/mcore/qwen2_5_vl/vision_config.py index 0631c90..57ca63f 100644 --- a/Agent0/executor_train/verl/verl/models/mcore/qwen2_5_vl/vision_config.py +++ b/Agent0/executor_train/verl/verl/models/mcore/qwen2_5_vl/vision_config.py @@ -31,7 +31,9 @@ def get_vision_model_config(config: TransformerConfig) -> TransformerConfig: config.ffn_hidden_size = 3456 if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None: - config.num_layers = 32 * parallel_state.get_virtual_pipeline_model_parallel_world_size() # depth + config.num_layers = ( + 32 * parallel_state.get_virtual_pipeline_model_parallel_world_size() + ) # 
depth else: config.num_layers = 32 # depth config.num_attention_heads = 16 # num_heads diff --git a/Agent0/executor_train/verl/verl/models/mcore/qwen2_5_vl/vision_model.py b/Agent0/executor_train/verl/verl/models/mcore/qwen2_5_vl/vision_model.py index 06b4fd3..66f47e7 100644 --- a/Agent0/executor_train/verl/verl/models/mcore/qwen2_5_vl/vision_model.py +++ b/Agent0/executor_train/verl/verl/models/mcore/qwen2_5_vl/vision_model.py @@ -46,14 +46,26 @@ def __init__( self.embed_dim = embed_dim kernel_size = [temporal_patch_size, patch_size, patch_size] - self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False) + self.proj = nn.Conv3d( + in_channels, + embed_dim, + kernel_size=kernel_size, + stride=kernel_size, + bias=False, + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: target_dtype = self.proj.weight.dtype hidden_states = hidden_states.view( - -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size + -1, + self.in_channels, + self.temporal_patch_size, + self.patch_size, + self.patch_size, + ) + hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view( + -1, self.embed_dim ) - hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim) return hidden_states @@ -65,7 +77,9 @@ def __init__(self, dim: int, theta: float = 10000.0) -> None: self.register_buffer("inv_freq", inv_freq, persistent=False) def forward(self, seqlen: int) -> torch.Tensor: - seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + seq = torch.arange( + seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype + ) freqs = torch.outer(seq, self.inv_freq) return freqs.float() @@ -141,7 +155,10 @@ def __init__( if self.post_process: self.projection = MultimodalProjector( - projection_config, projection_layer_spec, projection_type, projection_config.ffn_hidden_size + projection_config, + projection_layer_spec, + projection_type, + 
projection_config.ffn_hidden_size, ) else: self.projection = None @@ -192,14 +209,18 @@ def get_window_index(self, grid_thw): window_index: list = [] cu_window_seqlens: list = [0] window_index_id = 0 - vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size + vit_merger_window_size = ( + self.window_size // self.spatial_merge_size // self.patch_size + ) for grid_t, grid_h, grid_w in grid_thw: llm_grid_h, llm_grid_w = ( grid_h // self.spatial_merge_size, grid_w // self.spatial_merge_size, ) - index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w) + index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape( + grid_t, llm_grid_h, llm_grid_w + ) pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size @@ -222,7 +243,9 @@ def get_window_index(self, grid_thw): index_padded = index_padded.reshape(-1) index_new = index_padded[index_padded != -100] window_index.append(index_new + window_index_id) - cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1] + cu_seqlens_tmp = ( + seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1] + ) cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) window_index_id += (grid_t * llm_grid_h * llm_grid_w).item() window_index = torch.cat(window_index, dim=0) @@ -262,12 +285,16 @@ def forward( cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) seq_len, _ = vision_data.size() - vision_data = vision_data.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + vision_data = vision_data.reshape( + seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1 + ) vision_data = vision_data[window_index, :, :] vision_data = vision_data.reshape(seq_len, 1, -1) rotary_pos_emb = self.rot_pos_emb(grid_thw) - rotary_pos_emb = rotary_pos_emb.reshape(seq_len // 
self.spatial_merge_unit, self.spatial_merge_unit, -1) + rotary_pos_emb = rotary_pos_emb.reshape( + seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1 + ) rotary_pos_emb = rotary_pos_emb[window_index, :, :] rotary_pos_emb = rotary_pos_emb.reshape(seq_len, 1, 1, -1).repeat(1, 1, 1, 2) @@ -293,7 +320,9 @@ def build_packed_seq_params( ) -> PackedSeqParams: # NOTE: each frame is a sequence (rather than each grid) if grid_thw is not None: - seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]) + seqlens = torch.repeat_interleave( + grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0] + ) cu_seqlens = seqlens.cumsum(dim=0) cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0).int() else: diff --git a/Agent0/executor_train/verl/verl/models/mcore/qwen2_5_vl/vision_transformer_block.py b/Agent0/executor_train/verl/verl/models/mcore/qwen2_5_vl/vision_transformer_block.py index 8f765a0..8cd9122 100644 --- a/Agent0/executor_train/verl/verl/models/mcore/qwen2_5_vl/vision_transformer_block.py +++ b/Agent0/executor_train/verl/verl/models/mcore/qwen2_5_vl/vision_transformer_block.py @@ -34,7 +34,9 @@ def _checkpointed_forward( """Forward method with activation checkpointing.""" def custom(start: int, end: int): - def custom_forward(hidden_states, attention_mask, context, context_mask, rotary_pos_emb): + def custom_forward( + hidden_states, attention_mask, context, context_mask, rotary_pos_emb + ): for index in range(start, end): if index in fullatt_block_indexes: packed_seq_params_now = packed_seq_params_full @@ -105,12 +107,19 @@ def checkpoint_handler(forward_func): recompute_skip_num_layers += 1 if ( layer_idx >= recompute_skip_num_layers - and layer_idx < self.config.recompute_num_layers + recompute_skip_num_layers + and layer_idx + < self.config.recompute_num_layers + recompute_skip_num_layers ): - hidden_states, context = checkpoint_handler(custom(layer_idx, layer_idx + 1)) + hidden_states, context = checkpoint_handler( + custom(layer_idx, 
layer_idx + 1) + ) else: hidden_states, context = custom(layer_idx, layer_idx + 1)( - hidden_states, attention_mask, context, context_mask, rotary_pos_emb + hidden_states, + attention_mask, + context, + context_mask, + rotary_pos_emb, ) else: raise ValueError("Invalid activation recompute method.") @@ -164,7 +173,9 @@ def forward( [s, b, h], and optionally the updated context tensor if cross-attention is used. """ - inference_context = deprecate_inference_params(inference_context, inference_params) + inference_context = deprecate_inference_params( + inference_context, inference_params + ) # Delete the obsolete reference to the initial input tensor if necessary if isinstance(hidden_states, WrappedTensor): @@ -193,7 +204,9 @@ def forward( # likely redundant, since p2p_communication.py (likely originator) # already creates viewless tensors. That said, make_viewless_tensor() # is called here to be future-proof and corner-case-proof. - hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) + hidden_states = make_viewless_tensor( + inp=hidden_states, requires_grad=True, keep_graph=True + ) if self.config.sequence_parallel: rng_context = tensor_parallel.get_cuda_rng_tracker().fork() @@ -205,9 +218,15 @@ def forward( # if we are using other fp8 recipes, then the context manager enter&exit are free # we can wrap fp8_context within the for loop over layers, so that we can fine-grained # control which layer will be fp8 or bf16 - use_outer_fp8_context = self.config.fp8 and self.config.fp8_recipe == Fp8Recipe.delayed - use_inner_fp8_context = self.config.fp8 and self.config.fp8_recipe != Fp8Recipe.delayed - outer_fp8_context = get_fp8_context(self.config) if use_outer_fp8_context else nullcontext() + use_outer_fp8_context = ( + self.config.fp8 and self.config.fp8_recipe == Fp8Recipe.delayed + ) + use_inner_fp8_context = ( + self.config.fp8 and self.config.fp8_recipe != Fp8Recipe.delayed + ) + outer_fp8_context = ( + 
get_fp8_context(self.config) if use_outer_fp8_context else nullcontext() + ) with rng_context, outer_fp8_context: # Forward pass. @@ -226,7 +245,9 @@ def forward( else: for l_no, layer in enumerate(self.layers): inner_fp8_context = ( - get_fp8_context(self.config, layer.layer_number - 1) if use_inner_fp8_context else nullcontext() + get_fp8_context(self.config, layer.layer_number - 1) + if use_inner_fp8_context + else nullcontext() ) if l_no in fullatt_block_indexes: packed_seq_params_now = packed_seq_params_full @@ -252,7 +273,9 @@ def forward( and self.config.cpu_offloading and self.group_prefetch_offload_commit_async is not None ): - hidden_states = self.group_prefetch_offload_commit_async(hidden_states) + hidden_states = self.group_prefetch_offload_commit_async( + hidden_states + ) # Final layer norm. if self.final_layernorm is not None: @@ -260,6 +283,8 @@ def forward( # TENorm produces a "viewed" tensor. This will result in schedule.py's # deallocate_output_tensor() throwing an error, so a viewless tensor is # created to prevent this. 
- hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) + hidden_states = make_viewless_tensor( + inp=hidden_states, requires_grad=True, keep_graph=True + ) return hidden_states diff --git a/Agent0/executor_train/verl/verl/models/mcore/registry.py b/Agent0/executor_train/verl/verl/models/mcore/registry.py index 23f01e8..e78f33b 100644 --- a/Agent0/executor_train/verl/verl/models/mcore/registry.py +++ b/Agent0/executor_train/verl/verl/models/mcore/registry.py @@ -73,7 +73,9 @@ class SupportedModel(Enum): # Registry for model configuration converters -MODEL_CONFIG_CONVERTER_REGISTRY: dict[SupportedModel, Callable[[PretrainedConfig, torch.dtype], TransformerConfig]] = { +MODEL_CONFIG_CONVERTER_REGISTRY: dict[ + SupportedModel, Callable[[PretrainedConfig, torch.dtype], TransformerConfig] +] = { SupportedModel.LLAMA: hf_to_mcore_config_dense, SupportedModel.QWEN2: hf_to_mcore_config_dense, SupportedModel.QWEN2_MOE: hf_to_mcore_config_qwen2moe, @@ -154,7 +156,9 @@ def get_supported_model(model_type: str) -> SupportedModel: def hf_to_mcore_config( - hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs + hf_config: PretrainedConfig, + dtype: torch.dtype, + **override_transformer_config_kwargs, ) -> TransformerConfig: """Convert huggingface PretrainedConfig to mcore TransformerConfig. @@ -166,9 +170,13 @@ def hf_to_mcore_config( Returns: The mcore TransformerConfig. 
""" - assert len(hf_config.architectures) == 1, "Only one architecture is supported for now" + assert ( + len(hf_config.architectures) == 1 + ), "Only one architecture is supported for now" model = get_supported_model(hf_config.architectures[0]) - return MODEL_CONFIG_CONVERTER_REGISTRY[model](hf_config, dtype, **override_transformer_config_kwargs) + return MODEL_CONFIG_CONVERTER_REGISTRY[model]( + hf_config, dtype, **override_transformer_config_kwargs + ) def init_mcore_model( @@ -196,7 +204,9 @@ def init_mcore_model( Returns: The initialized model. """ - assert len(hf_config.architectures) == 1, "Only one architecture is supported for now" + assert ( + len(hf_config.architectures) == 1 + ), "Only one architecture is supported for now" model = get_supported_model(hf_config.architectures[0]) initializer_cls = MODEL_INITIALIZER_REGISTRY[model] initializer = initializer_cls(tfconfig, hf_config) @@ -213,7 +223,9 @@ def get_mcore_forward_fn(hf_config: PretrainedConfig) -> Callable: """ Get the forward function for given model architecture. """ - assert len(hf_config.architectures) == 1, "Only one architecture is supported for now" + assert ( + len(hf_config.architectures) == 1 + ), "Only one architecture is supported for now" model = get_supported_model(hf_config.architectures[0]) return MODEL_FORWARD_REGISTRY[model] @@ -222,16 +234,22 @@ def get_mcore_forward_fused_fn(hf_config: PretrainedConfig) -> Callable: """ Get the forward function for given model architecture. 
""" - assert len(hf_config.architectures) == 1, "Only one architecture is supported for now" + assert ( + len(hf_config.architectures) == 1 + ), "Only one architecture is supported for now" model = get_supported_model(hf_config.architectures[0]) return MODEL_FORWARD_FUSED_REGISTRY[model] -def get_mcore_weight_converter(hf_config: PretrainedConfig, dtype: torch.dtype) -> Callable: +def get_mcore_weight_converter( + hf_config: PretrainedConfig, dtype: torch.dtype +) -> Callable: """ Get the weight converter for given model architecture. """ - assert len(hf_config.architectures) == 1, "Only one architecture is supported for now" + assert ( + len(hf_config.architectures) == 1 + ), "Only one architecture is supported for now" model = get_supported_model(hf_config.architectures[0]) tfconfig = hf_to_mcore_config(hf_config, dtype) return MODEL_WEIGHT_CONVERTER_REGISTRY[model](hf_config, tfconfig) diff --git a/Agent0/executor_train/verl/verl/models/mcore/saver.py b/Agent0/executor_train/verl/verl/models/mcore/saver.py index 2a954b2..a9361fe 100644 --- a/Agent0/executor_train/verl/verl/models/mcore/saver.py +++ b/Agent0/executor_train/verl/verl/models/mcore/saver.py @@ -28,7 +28,11 @@ def _megatron_calc_global_rank( - tp_rank: int = 0, dp_rank: int = 0, pp_rank: int = 0, cp_rank: int = 0, ep_rank: int = 0 + tp_rank: int = 0, + dp_rank: int = 0, + pp_rank: int = 0, + cp_rank: int = 0, + ep_rank: int = 0, ): """Calculate global rank with support for CP/EP parallelism""" @@ -41,9 +45,9 @@ def _megatron_calc_global_rank( # Verify total GPU count matches (must be consistent with parallel_state.py) total_size = tp_size * dp_size * pp_size * cp_size - assert total_size == torch.distributed.get_world_size(), ( - f"{tp_size}x{dp_size}x{pp_size}x{cp_size} != {torch.distributed.get_world_size()}" - ) + assert ( + total_size == torch.distributed.get_world_size() + ), f"{tp_size}x{dp_size}x{pp_size}x{cp_size} != {torch.distributed.get_world_size()}" # Core calculation logic (corresponds 
to RankGenerator order parameter) # Assumes default order is "tp-cp-ep-dp-pp" @@ -69,7 +73,8 @@ def _megatron_calc_layer_map(config): for pp_rank_idx in range(pp_size): for virtual_pp_rank_idx in range(virtual_pp_size): layer_offset = ( - virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model + virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + + pp_rank_idx * num_layers_per_model ) for layer_idx in range(num_layers_per_model): layer_map[layer_offset + layer_idx] = ( @@ -80,7 +85,9 @@ def _megatron_calc_layer_map(config): return layer_map -def merge_megatron_ckpt_gptmodel(wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False): +def merge_megatron_ckpt_gptmodel( + wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False +): """Merge sharded parameters of a Megatron module into a merged checkpoint. Args: @@ -123,10 +130,10 @@ def _get_gpt_model(model): for i, wrapped_model in enumerate(wrapped_models): models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module)) - assert len(models[i].decoder.layers) == num_layers_per_model, ( - "len model layers {} not equal to num_layers_per_model {}".format( - len(models[i].decoder.layers), num_layers_per_model - ) + assert ( + len(models[i].decoder.layers) == num_layers_per_model + ), "len model layers {} not equal to num_layers_per_model {}".format( + len(models[i].decoder.layers), num_layers_per_model ) state_dict = dict() @@ -142,7 +149,9 @@ def _broadcast_tensor(tensor, name, src_pp_rank) -> torch.Tensor: """broadcast tensor across mp_group""" nonlocal state_dict nonlocal mp_group - src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank) + src_rank = _megatron_calc_global_rank( + tp_rank=0, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank + ) if torch.distributed.get_rank() == src_rank: if tensor is None: @@ -177,13 +186,17 @@ def _broadcast_tensor(tensor, 
name, src_pp_rank) -> torch.Tensor: if torch.distributed.get_rank() == 0: state_dict[name] = _get_cpu_tensor(weight) - def _broadcast_tp_shard_tensor(tensor, name, src_pp_rank, concat_dim=0, mutate_func=None) -> torch.Tensor: + def _broadcast_tp_shard_tensor( + tensor, name, src_pp_rank, concat_dim=0, mutate_func=None + ) -> torch.Tensor: """broadcast tensor in tp shards across mp_group""" nonlocal state_dict nonlocal mp_group # tp_rank = mpu.get_tensor_model_parallel_rank() tp_size = mpu.get_tensor_model_parallel_world_size() - src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank) + src_rank = _megatron_calc_global_rank( + tp_rank=0, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank + ) chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None @@ -205,8 +218,14 @@ def _broadcast_tp_shard_tensor(tensor, name, src_pp_rank, concat_dim=0, mutate_f chunk_tensors = [None] * tp_size for i in range(tp_size): - cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank) - sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor + cur_src_rank = _megatron_calc_global_rank( + tp_rank=i, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank + ) + sync_tensor = ( + tensor + if torch.distributed.get_rank() == cur_src_rank + else buffer_tensor + ) dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group) if torch.distributed.get_rank() == 0: @@ -218,13 +237,17 @@ def _broadcast_tp_shard_tensor(tensor, name, src_pp_rank, concat_dim=0, mutate_f full_tensor = mutate_func(full_tensor) state_dict[name] = full_tensor - def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name, src_pp_rank) -> torch.Tensor: + def _broadcast_tp_shard_tensor_gate_up( + tensor, gate_name, up_name, src_pp_rank + ) -> torch.Tensor: """broadcast tensor in tp shards across mp_group""" nonlocal state_dict nonlocal mp_group # tp_rank = 
mpu.get_tensor_model_parallel_rank() tp_size = mpu.get_tensor_model_parallel_world_size() - src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank) + src_rank = _megatron_calc_global_rank( + tp_rank=0, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank + ) chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None @@ -233,7 +256,9 @@ def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name, src_pp_rank) chunk_shape = obj_list[0] if chunk_shape is None: # all or none ranks in the mp_group should reach here - print_rank_0(f"tp_shard tensor:[{gate_name, up_name}] not exist, skip collecting") + print_rank_0( + f"tp_shard tensor:[{gate_name, up_name}] not exist, skip collecting" + ) return buffer_tensor = torch.empty( @@ -246,8 +271,14 @@ def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name, src_pp_rank) chunk_tensors = [None] * tp_size for i in range(tp_size): - cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank) - sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor + cur_src_rank = _megatron_calc_global_rank( + tp_rank=i, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank + ) + sync_tensor = ( + tensor + if torch.distributed.get_rank() == cur_src_rank + else buffer_tensor + ) dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group) if torch.distributed.get_rank() == 0: @@ -259,7 +290,9 @@ def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name, src_pp_rank) gate_weight_list = [] up_weight_list = [] for i in range(tp_size): - gate_up_weight_tp = full_tensor[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)] + gate_up_weight_tp = full_tensor[ + intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1) + ] gate_weight_tp = gate_up_weight_tp[:intermediate_size_tp] up_weight_tp = gate_up_weight_tp[intermediate_size_tp:] gate_weight_list.append(gate_weight_tp) @@ -274,7 
+307,9 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, src_pp_rank): nonlocal mp_group # tp_rank = mpu.get_tensor_model_parallel_rank() tp_size = mpu.get_tensor_model_parallel_world_size() - src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank) + src_rank = _megatron_calc_global_rank( + tp_rank=0, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank + ) chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None @@ -296,8 +331,14 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, src_pp_rank): chunk_tensors = [None] * tp_size for i in range(tp_size): - cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank) - sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor + cur_src_rank = _megatron_calc_global_rank( + tp_rank=i, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank + ) + sync_tensor = ( + tensor + if torch.distributed.get_rank() == cur_src_rank + else buffer_tensor + ) dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group) if torch.distributed.get_rank() == 0: @@ -308,20 +349,30 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, src_pp_rank): q_weight_list = [] k_weight_list = [] v_weight_list = [] - hidden_size_per_head = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + hidden_size_per_head = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) if config.num_key_value_heads >= tp_size: q_size_tp = hidden_size_per_head * config.num_attention_heads // tp_size - kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size + kv_size_tp = ( + hidden_size_per_head * config.num_key_value_heads // tp_size + ) total_size = q_size_tp + 2 * kv_size_tp for i in range(tp_size): - num_query_groups_per_partition = wrapped_models[0].config.num_query_groups // tp_size + num_query_groups_per_partition = 
( + wrapped_models[0].config.num_query_groups // tp_size + ) qkv_part = full_tensor[i * total_size : (i + 1) * total_size] q_size_chunk = q_size_tp // num_query_groups_per_partition kv_size_chunk = kv_size_tp // num_query_groups_per_partition - for qkv_part_chunk in qkv_part.chunk(num_query_groups_per_partition): + for qkv_part_chunk in qkv_part.chunk( + num_query_groups_per_partition + ): q_part = qkv_part_chunk[:q_size_chunk] - k_part = qkv_part_chunk[q_size_chunk : q_size_chunk + kv_size_chunk] + k_part = qkv_part_chunk[ + q_size_chunk : q_size_chunk + kv_size_chunk + ] v_part = qkv_part_chunk[q_size_chunk + kv_size_chunk :] q_weight_list.append(q_part) k_weight_list.append(k_part) @@ -331,13 +382,19 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, src_pp_rank): kv_size_tp = hidden_size_per_head total_size = q_size_tp + 2 * kv_size_tp for i in range(tp_size): - num_query_groups_per_partition = wrapped_models[0].config.num_query_groups // tp_size + num_query_groups_per_partition = ( + wrapped_models[0].config.num_query_groups // tp_size + ) qkv_part = full_tensor[i * total_size : (i + 1) * total_size] q_size_chunk = q_size_tp // num_query_groups_per_partition kv_size_chunk = kv_size_tp // num_query_groups_per_partition - for qkv_part_chunk in qkv_part.chunk(num_query_groups_per_partition): + for qkv_part_chunk in qkv_part.chunk( + num_query_groups_per_partition + ): q_part = qkv_part_chunk[:q_size_chunk] - k_part = qkv_part_chunk[q_size_chunk : q_size_chunk + kv_size_chunk] + k_part = qkv_part_chunk[ + q_size_chunk : q_size_chunk + kv_size_chunk + ] v_part = qkv_part_chunk[q_size_chunk + kv_size_chunk :] q_weight_list.append(q_part) if i * config.num_key_value_heads % tp_size == 0: @@ -454,12 +511,20 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, src_pp_rank): if is_value_model: lm_head_weight = None if pp_rank == pp_size - 1: - lm_head_weight = getattr(gpt_model_module.output_layer, "weight", None) - 
_broadcast_tensor(lm_head_weight, "lm_head.weight", src_pp_rank=pp_size - 1) + lm_head_weight = getattr( + gpt_model_module.output_layer, "weight", None + ) + _broadcast_tensor( + lm_head_weight, "lm_head.weight", src_pp_rank=pp_size - 1 + ) else: _broadcast_tp_shard_tensor( - getattr(gpt_model_module.output_layer, "weight", None) if pp_rank == pp_size - 1 else None, + ( + getattr(gpt_model_module.output_layer, "weight", None) + if pp_rank == pp_size - 1 + else None + ), "lm_head.weight", src_pp_rank=pp_size - 1, ) @@ -478,16 +543,22 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, src_pp_rank): def merge_megatron_ckpt_gptmodel_qwen_moe( wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False ): - raise NotImplementedError("merge_megatron_ckpt_gptmodel_qwen_moe is not implemented") + raise NotImplementedError( + "merge_megatron_ckpt_gptmodel_qwen_moe is not implemented" + ) def merge_megatron_ckpt_gptmodel_qwen2_5_vl( wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False ): - raise NotImplementedError("merge_megatron_ckpt_gptmodel_qwen2_5_vl is not implemented") + raise NotImplementedError( + "merge_megatron_ckpt_gptmodel_qwen2_5_vl is not implemented" + ) -def merge_megatron_ckpt_gptmodel_dpskv3(wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False): +def merge_megatron_ckpt_gptmodel_dpskv3( + wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False +): raise NotImplementedError("merge_megatron_ckpt_gptmodel_dpskv3 is not implemented") diff --git a/Agent0/executor_train/verl/verl/models/mcore/util.py b/Agent0/executor_train/verl/verl/models/mcore/util.py index c1ef7a2..3821625 100644 --- a/Agent0/executor_train/verl/verl/models/mcore/util.py +++ b/Agent0/executor_train/verl/verl/models/mcore/util.py @@ -41,18 +41,24 @@ def preprocess_packed_seqs( seqlens_in_batch_padded = seqlens_in_batch + pad_size cu_seqlens = torch.zeros(batch_size + 1, 
dtype=torch.int32, device=input_ids.device) cu_seqlens[1:] = torch.cumsum(seqlens_in_batch, dim=0) - cu_seqlens_padded = torch.zeros(batch_size + 1, dtype=torch.int32, device=input_ids.device) + cu_seqlens_padded = torch.zeros( + batch_size + 1, dtype=torch.int32, device=input_ids.device + ) cu_seqlens_padded[1:] = torch.cumsum(seqlens_in_batch_padded, dim=0) max_seqlen_in_batch = seqlens_in_batch_padded.max().item() shape = list(input_ids.shape[1:]) shape[0] = seqlens_in_batch_padded.sum().item() // cp_size if pre_process: - input_ids_rmpad = torch.zeros(shape, dtype=input_ids.dtype, device=input_ids.device) + input_ids_rmpad = torch.zeros( + shape, dtype=input_ids.dtype, device=input_ids.device + ) for i in range(batch_size): if cp_size <= 1: seqlen = seqlens_in_batch[i] - input_ids_rmpad[cu_seqlens_padded[i] : cu_seqlens_padded[i] + seqlen] = input_ids[i, attention_mask[i]] + input_ids_rmpad[ + cu_seqlens_padded[i] : cu_seqlens_padded[i] + seqlen + ] = input_ids[i, attention_mask[i]] continue seqlen = seqlens_in_batch_padded[i] // cp_size half_seqlen = seqlen // 2 @@ -68,9 +74,9 @@ def preprocess_packed_seqs( remain_end = min(remain_end, d.shape[0]) remain_len = remain_end - remain_start if remain_len > 0: - input_ids_rmpad[start_idx + half_seqlen : start_idx + half_seqlen + remain_len] = d[ - remain_start:remain_end - ] + input_ids_rmpad[ + start_idx + half_seqlen : start_idx + half_seqlen + remain_len + ] = d[remain_start:remain_end] packed_seq_params = PackedSeqParams( qkv_format="thd", @@ -100,7 +106,9 @@ def postprocess_packed_seqs( """ if not post_process: return output - shape = [batch_size, seq_len] + list(output.shape[2:]) # 1,packed, dim -> batch_size, seq_len, dim + shape = [batch_size, seq_len] + list( + output.shape[2:] + ) # 1,packed, dim -> batch_size, seq_len, dim output_new = torch.zeros(shape, dtype=output.dtype, device=output.device) cp_size = mpu.get_context_parallel_world_size() @@ -109,7 +117,9 @@ def postprocess_packed_seqs( # output 
shape: [1, packed_len, hidden_dim] # need to gather across cp group and concatenate in sequence dimension output_list = [torch.empty_like(output) for _ in range(cp_size)] - torch.distributed.all_gather(output_list, output.detach(), group=mpu.get_context_parallel_group()) + torch.distributed.all_gather( + output_list, output.detach(), group=mpu.get_context_parallel_group() + ) output_list[mpu.get_context_parallel_rank()] = output else: output_list = [output] @@ -117,11 +127,15 @@ def postprocess_packed_seqs( if cp_size <= 1: s = attention_mask[i].sum().item() output_new[i, attention_mask[i]] = output[0][ - packed_seq_params.cu_seqlens_q_padded[i] : packed_seq_params.cu_seqlens_q_padded[i] + s + packed_seq_params.cu_seqlens_q_padded[ + i + ] : packed_seq_params.cu_seqlens_q_padded[i] + + s ] continue s_len_padded_chunk = ( - packed_seq_params.cu_seqlens_q_padded[i + 1] - packed_seq_params.cu_seqlens_q_padded[i] + packed_seq_params.cu_seqlens_q_padded[i + 1] + - packed_seq_params.cu_seqlens_q_padded[i] ) // cp_size half_seqlen = s_len_padded_chunk // 2 s_len = attention_mask[i].sum().item() @@ -133,10 +147,16 @@ def postprocess_packed_seqs( packed_start_idx = packed_seq_params.cu_seqlens_q_padded[i] // cp_size o0, o1 = ( o[packed_start_idx : packed_start_idx + half_seqlen], - o[packed_start_idx + half_seqlen : packed_start_idx + s_len_padded_chunk], + o[ + packed_start_idx + + half_seqlen : packed_start_idx + + s_len_padded_chunk + ], ) tmp[j * half_seqlen : (j + 1) * half_seqlen] = o0 - tmp[s_len_padded - (j + 1) * half_seqlen : s_len_padded - j * half_seqlen] = o1 + tmp[ + s_len_padded - (j + 1) * half_seqlen : s_len_padded - j * half_seqlen + ] = o1 output_new[i, attention_mask[i]] = tmp[:s_len] return output_new @@ -167,11 +187,17 @@ def remove_left_padding( seq_len = seq_len + pad_size shape[1] = seq_len if pre_process: - new_input_ids = torch.zeros(dtype=input_ids.dtype, device=input_ids.device, size=shape) + new_input_ids = torch.zeros( + dtype=input_ids.dtype, 
device=input_ids.device, size=shape + ) new_attention_mask = torch.zeros( - dtype=attention_mask.dtype, device=attention_mask.device, size=(batch_size, seq_len) + dtype=attention_mask.dtype, + device=attention_mask.device, + size=(batch_size, seq_len), + ) + new_position_ids = torch.zeros( + dtype=position_ids.dtype, device=position_ids.device, size=(batch_size, seq_len) ) - new_position_ids = torch.zeros(dtype=position_ids.dtype, device=position_ids.device, size=(batch_size, seq_len)) for i in range(batch_size): if pre_process: new_input_ids[i, : seq_lens[i]] = input_ids[i, attention_mask[i]] @@ -232,9 +258,19 @@ def postprocess_packed_seqs_for_dict_output( output.log_probs = output.log_probs.view(1, -1) output.log_probs = output.log_probs.masked_fill(~labels_mask, 0.0) ret["entropy"] = postprocess_packed_seqs( - output.entropy, packed_seq_params, attention_mask, batch_size, seq_len, post_process=post_process + output.entropy, + packed_seq_params, + attention_mask, + batch_size, + seq_len, + post_process=post_process, ) ret["log_probs"] = postprocess_packed_seqs( - output.log_probs, packed_seq_params, attention_mask, batch_size, seq_len, post_process=post_process + output.log_probs, + packed_seq_params, + attention_mask, + batch_size, + seq_len, + post_process=post_process, ) return ret diff --git a/Agent0/executor_train/verl/verl/models/mcore/weight_converter.py b/Agent0/executor_train/verl/verl/models/mcore/weight_converter.py index 791513f..f71f7d1 100644 --- a/Agent0/executor_train/verl/verl/models/mcore/weight_converter.py +++ b/Agent0/executor_train/verl/verl/models/mcore/weight_converter.py @@ -27,24 +27,37 @@ def __init__(self, hf_config: PretrainedConfig, mcore_config: TransformerConfig) self.hf_config = hf_config self.mcore_config = mcore_config - def convert_param(self, name: str, params_one_group: list[torch.Tensor]) -> torch.Tensor: + def convert_param( + self, name: str, params_one_group: list[torch.Tensor] + ) -> torch.Tensor: raise 
NotImplementedError class McoreToHFWeightConverterDense(McoreToHFWeightConverterBase): - def _convert_attention_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: + def _convert_attention_param( + self, name: str, params: list[torch.Tensor] + ) -> tuple[list[str], list[torch.Tensor]]: # 'decoder.layers.0.self_attention.linear_proj.weight' # 'decoder.layers.0.self_attention.linear_qkv.layer_norm_weight' # 'decoder.layers.0.self_attention.linear_qkv.weight' # 'decoder.layers.0.self_attention.linear_qkv.bias' layer_number = name.split(".")[2] convert_names = [] - if "self_attention.linear_qkv.bias" in name or "self_attention.linear_qkv.weight" in name: + if ( + "self_attention.linear_qkv.bias" in name + or "self_attention.linear_qkv.weight" in name + ): param_type = name.split(".")[-1] assert param_type == "bias" or param_type == "weight" - convert_names.append(f"model.layers.{layer_number}.self_attn.q_proj.{param_type}") - convert_names.append(f"model.layers.{layer_number}.self_attn.k_proj.{param_type}") - convert_names.append(f"model.layers.{layer_number}.self_attn.v_proj.{param_type}") + convert_names.append( + f"model.layers.{layer_number}.self_attn.q_proj.{param_type}" + ) + convert_names.append( + f"model.layers.{layer_number}.self_attn.k_proj.{param_type}" + ) + convert_names.append( + f"model.layers.{layer_number}.self_attn.v_proj.{param_type}" + ) assert len(params) == 3 elif "self_attention.linear_proj.weight" in name: convert_names.append(f"model.layers.{layer_number}.self_attn.o_proj.weight") @@ -62,7 +75,9 @@ def _convert_attention_param(self, name: str, params: list[torch.Tensor]) -> tup raise NotImplementedError(f"Unsupported parameter name: {name}") return convert_names, params - def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: + def _convert_mlp_param( + self, name: str, params: list[torch.Tensor] + ) -> tuple[list[str], list[torch.Tensor]]: # 
'decoder.layers.0.mlp.linear_fc1.layer_norm_weight' # 'decoder.layers.0.mlp.linear_fc1.weight' # 'decoder.layers.0.mlp.linear_fc2.weight' @@ -74,7 +89,9 @@ def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[lis convert_names.append(f"model.layers.{layer_number}.mlp.up_proj.weight") assert len(params) == 2 elif "mlp.linear_fc1.layer_norm_weight" in name: - convert_names.append(f"model.layers.{layer_number}.post_attention_layernorm.weight") + convert_names.append( + f"model.layers.{layer_number}.post_attention_layernorm.weight" + ) assert len(params) == 1 elif "mlp.linear_fc2.weight" in name: convert_names.append(f"model.layers.{layer_number}.mlp.down_proj.weight") @@ -83,7 +100,9 @@ def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[lis raise NotImplementedError(f"Unsupported parameter name: {name}") return convert_names, params - def convert_param(self, name: str, params_one_group: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: + def convert_param( + self, name: str, params_one_group: list[torch.Tensor] + ) -> tuple[list[str], list[torch.Tensor]]: direct_name_mapping = { "embedding.word_embeddings.weight": "model.embed_tokens.weight", "decoder.final_layernorm.weight": "model.norm.weight", @@ -101,7 +120,9 @@ def convert_param(self, name: str, params_one_group: list[torch.Tensor]) -> tupl class McoreToHFWeightConverterQwen2Moe(McoreToHFWeightConverterDense): - def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: + def _convert_mlp_param( + self, name: str, params: list[torch.Tensor] + ) -> tuple[list[str], list[torch.Tensor]]: # 'decoder.layers.0.pre_mlp_layernorm.weight', # 'decoder.layers.0.mlp.router.weight', # 'decoder.layers.0.mlp.shared_experts.gate_weight', @@ -118,29 +139,45 @@ def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[lis layer_number = name.split(".")[2] convert_names = [] if "pre_mlp_layernorm" in 
name: - convert_names.append(f"model.layers.{layer_number}.post_attention_layernorm.weight") + convert_names.append( + f"model.layers.{layer_number}.post_attention_layernorm.weight" + ) assert len(params) == 1 elif "mlp.router.weight" in name: convert_names.append(f"model.layers.{layer_number}.mlp.gate.weight") assert len(params) == 1 elif "shared_experts.gate_weight" in name: - convert_names.append(f"model.layers.{layer_number}.mlp.shared_expert_gate.weight") + convert_names.append( + f"model.layers.{layer_number}.mlp.shared_expert_gate.weight" + ) assert len(params) == 1 elif "shared_experts.linear_fc1.weight" in name: # split gate_proj and up_proj - convert_names.append(f"model.layers.{layer_number}.mlp.shared_expert.gate_proj.weight") - convert_names.append(f"model.layers.{layer_number}.mlp.shared_expert.up_proj.weight") + convert_names.append( + f"model.layers.{layer_number}.mlp.shared_expert.gate_proj.weight" + ) + convert_names.append( + f"model.layers.{layer_number}.mlp.shared_expert.up_proj.weight" + ) assert len(params) == 2 elif "shared_experts.linear_fc2.weight" in name: - convert_names.append(f"model.layers.{layer_number}.mlp.shared_expert.down_proj.weight") + convert_names.append( + f"model.layers.{layer_number}.mlp.shared_expert.down_proj.weight" + ) assert len(params) == 1 elif "mlp.experts.linear_fc1" in name: # split gate_proj and up_proj expert_id = name.split("weight")[-1] - convert_names.append(f"model.layers.{layer_number}.mlp.experts.{expert_id}.gate_proj.weight") - convert_names.append(f"model.layers.{layer_number}.mlp.experts.{expert_id}.up_proj.weight") + convert_names.append( + f"model.layers.{layer_number}.mlp.experts.{expert_id}.gate_proj.weight" + ) + convert_names.append( + f"model.layers.{layer_number}.mlp.experts.{expert_id}.up_proj.weight" + ) assert len(params) == 2 elif "mlp.experts.linear_fc2" in name: expert_id = name.split("weight")[-1] - 
convert_names.append(f"model.layers.{layer_number}.mlp.experts.{expert_id}.down_proj.weight") + convert_names.append( + f"model.layers.{layer_number}.mlp.experts.{expert_id}.down_proj.weight" + ) assert len(params) == 1 else: raise NotImplementedError(f"Unsupported parameter name: {name}") @@ -148,7 +185,9 @@ def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[lis class McoreToHFWeightConverterQwen2_5_VL(McoreToHFWeightConverterDense): - def convert_param(self, name: str, params_one_group: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: + def convert_param( + self, name: str, params_one_group: list[torch.Tensor] + ) -> tuple[list[str], list[torch.Tensor]]: direct_name_mapping = { "language_model.embedding.word_embeddings.weight": "model.embed_tokens.weight", "language_model.decoder.final_layernorm.weight": "model.norm.weight", @@ -170,7 +209,9 @@ def convert_param(self, name: str, params_one_group: list[torch.Tensor]) -> tupl else: raise NotImplementedError(f"Unsupported parameter name: {name}") - def _convert_attention_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: + def _convert_attention_param( + self, name: str, params: list[torch.Tensor] + ) -> tuple[list[str], list[torch.Tensor]]: model_type, _, _, layer_number = name.split(".")[:4] convert_names = [] @@ -214,7 +255,9 @@ def _convert_attention_param(self, name: str, params: list[torch.Tensor]) -> tup if "bias" in name_after_layer: convert_names.append(f"visual.blocks.{layer_number}.attn.qkv.bias") else: - convert_names.append(f"visual.blocks.{layer_number}.attn.qkv.weight") + convert_names.append( + f"visual.blocks.{layer_number}.attn.qkv.weight" + ) else: assert len(params) == 1 convert_names.append(f"visual.blocks.{layer_number}.{mapped_name}") @@ -222,7 +265,9 @@ def _convert_attention_param(self, name: str, params: list[torch.Tensor]) -> tup raise NotImplementedError(f"Unsupported model type: {model_type}") return 
convert_names, params - def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: + def _convert_mlp_param( + self, name: str, params: list[torch.Tensor] + ) -> tuple[list[str], list[torch.Tensor]]: model_type, _, _, layer_number = name.split(".")[:4] convert_names = [] @@ -267,7 +312,9 @@ def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[lis class McoreToHFWeightConverterDpskv3(McoreToHFWeightConverterBase): - def _convert_attention_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: + def _convert_attention_param( + self, name: str, params: list[torch.Tensor] + ) -> tuple[list[str], list[torch.Tensor]]: # mcore # 'decoder.layers.0.input_layernorm.weight' # 'decoder.layers.0.self_attention.linear_proj.weight' @@ -303,10 +350,14 @@ def _convert_attention_param(self, name: str, params: list[torch.Tensor]) -> tup convert_names = [] layer_number = name.split(".")[2] name_after_layer = name.split(f".{layer_number}.")[1] - convert_names.append(f"model.layers.{layer_number}.{name_map_after_layer[name_after_layer]}") + convert_names.append( + f"model.layers.{layer_number}.{name_map_after_layer[name_after_layer]}" + ) return convert_names, params - def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: + def _convert_mlp_param( + self, name: str, params: list[torch.Tensor] + ) -> tuple[list[str], list[torch.Tensor]]: # mcore dense # 'decoder.layers.0.mlp.linear_fc1.layer_norm_weight' # 'decoder.layers.0.mlp.linear_fc2.weight' @@ -367,20 +418,30 @@ def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[lis else: if "mlp.experts.linear_fc1.weight" in name: expert_id = name.split("weight")[-1] - convert_names.append(f"model.layers.{layer_number}.mlp.experts.{expert_id}.gate_proj.weight") - convert_names.append(f"model.layers.{layer_number}.mlp.experts.{expert_id}.up_proj.weight") + 
convert_names.append( + f"model.layers.{layer_number}.mlp.experts.{expert_id}.gate_proj.weight" + ) + convert_names.append( + f"model.layers.{layer_number}.mlp.experts.{expert_id}.up_proj.weight" + ) assert len(params) == 2 elif "mlp.experts.linear_fc2.weight" in name: expert_id = name.split("weight")[-1] - convert_names.append(f"model.layers.{layer_number}.mlp.experts.{expert_id}.down_proj.weight") + convert_names.append( + f"model.layers.{layer_number}.mlp.experts.{expert_id}.down_proj.weight" + ) assert len(params) == 1 else: raise NotImplementedError(f"Unsupported parameter name: {name}") return convert_names, params - def _convert_mtp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: - assert self.mcore_config.mtp_num_layers == 1, "only support one mtp layer for now" + def _convert_mtp_param( + self, name: str, params: list[torch.Tensor] + ) -> tuple[list[str], list[torch.Tensor]]: + assert ( + self.mcore_config.mtp_num_layers == 1 + ), "only support one mtp layer for now" assert self.mcore_config.num_layers == 61, "only support 61 layers for now" direct_name_mapping = { "mtp.layers.0.enorm.weight": "model.layers.61.enorm.weight", @@ -390,7 +451,9 @@ def _convert_mtp_param(self, name: str, params: list[torch.Tensor]) -> tuple[lis } if name in direct_name_mapping: return [direct_name_mapping[name]], [params[0]] - assert "mtp.layers.0.transformer_layer" in name, "only support transformer layer for now" + assert ( + "mtp.layers.0.transformer_layer" in name + ), "only support transformer layer for now" # use proxy name to convert proxy_name = name.replace("mtp.layers.0.transformer_layer", "decoder.layers.61") if "self_attention" in proxy_name or "input_layernorm.weight" in proxy_name: @@ -401,7 +464,9 @@ def _convert_mtp_param(self, name: str, params: list[torch.Tensor]) -> tuple[lis raise NotImplementedError(f"Unsupported parameter name: {name}") return convert_names, params - def convert_param(self, name: str, 
params_one_group: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: + def convert_param( + self, name: str, params_one_group: list[torch.Tensor] + ) -> tuple[list[str], list[torch.Tensor]]: direct_name_mapping = { "embedding.word_embeddings.weight": "model.embed_tokens.weight", "decoder.final_layernorm.weight": "model.norm.weight", @@ -420,7 +485,9 @@ def convert_param(self, name: str, params_one_group: list[torch.Tensor]) -> tupl class McoreToHFWeightConverterMixtral(McoreToHFWeightConverterDense): - def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: + def _convert_mlp_param( + self, name: str, params: list[torch.Tensor] + ) -> tuple[list[str], list[torch.Tensor]]: # decoder.layers.0.mlp.router.weight # decoder.layers.0.mlp.experts.linear_fc1.weight0 - weight7 # decoder.layers.0.mlp.experts.linear_fc2.weight0 - weight7 @@ -428,23 +495,35 @@ def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[lis layer_number = name.split(".")[2] convert_names = [] if "pre_mlp_layernorm" in name: - convert_names.append(f"model.layers.{layer_number}.post_attention_layernorm.weight") + convert_names.append( + f"model.layers.{layer_number}.post_attention_layernorm.weight" + ) elif "mlp.router.weight" in name: - convert_names.append(f"model.layers.{layer_number}.block_sparse_moe.gate.weight") + convert_names.append( + f"model.layers.{layer_number}.block_sparse_moe.gate.weight" + ) elif "mlp.experts.linear_fc1.weight" in name: expert_id = name.split("weight")[-1] - convert_names.append(f"model.layers.{layer_number}.block_sparse_moe.experts.{expert_id}.w1.weight") - convert_names.append(f"model.layers.{layer_number}.block_sparse_moe.experts.{expert_id}.w3.weight") + convert_names.append( + f"model.layers.{layer_number}.block_sparse_moe.experts.{expert_id}.w1.weight" + ) + convert_names.append( + f"model.layers.{layer_number}.block_sparse_moe.experts.{expert_id}.w3.weight" + ) elif 
"mlp.experts.linear_fc2.weight" in name: expert_id = name.split("weight")[-1] - convert_names.append(f"model.layers.{layer_number}.block_sparse_moe.experts.{expert_id}.w2.weight") + convert_names.append( + f"model.layers.{layer_number}.block_sparse_moe.experts.{expert_id}.w2.weight" + ) else: raise NotImplementedError(f"Unsupported parameter name: {name}") return convert_names, params class McoreToHFWeightConverterQwen3Moe(McoreToHFWeightConverterDense): - def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: + def _convert_mlp_param( + self, name: str, params: list[torch.Tensor] + ) -> tuple[list[str], list[torch.Tensor]]: # qwen3 moe no share expert # 'decoder.layers.0.pre_mlp_layernorm.weight', @@ -460,19 +539,27 @@ def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[lis layer_number = name.split(".")[2] convert_names = [] if "pre_mlp_layernorm" in name: - convert_names.append(f"model.layers.{layer_number}.post_attention_layernorm.weight") + convert_names.append( + f"model.layers.{layer_number}.post_attention_layernorm.weight" + ) assert len(params) == 1 elif "mlp.router.weight" in name: convert_names.append(f"model.layers.{layer_number}.mlp.gate.weight") assert len(params) == 1 elif "mlp.experts.linear_fc1" in name: # split gate_proj and up_proj expert_id = name.split("weight")[-1] - convert_names.append(f"model.layers.{layer_number}.mlp.experts.{expert_id}.gate_proj.weight") - convert_names.append(f"model.layers.{layer_number}.mlp.experts.{expert_id}.up_proj.weight") + convert_names.append( + f"model.layers.{layer_number}.mlp.experts.{expert_id}.gate_proj.weight" + ) + convert_names.append( + f"model.layers.{layer_number}.mlp.experts.{expert_id}.up_proj.weight" + ) assert len(params) == 2 elif "mlp.experts.linear_fc2" in name: expert_id = name.split("weight")[-1] - convert_names.append(f"model.layers.{layer_number}.mlp.experts.{expert_id}.down_proj.weight") + convert_names.append( 
+ f"model.layers.{layer_number}.mlp.experts.{expert_id}.down_proj.weight" + ) assert len(params) == 1 else: raise NotImplementedError(f"Unsupported parameter name: {name}") diff --git a/Agent0/executor_train/verl/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader.py b/Agent0/executor_train/verl/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader.py index 3168635..d6db5d9 100644 --- a/Agent0/executor_train/verl/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader.py +++ b/Agent0/executor_train/verl/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader.py @@ -39,7 +39,8 @@ def _megatron_calc_layer_map(config): for pp_rank_idx in range(pp_size): for virtual_pp_rank_idx in range(virtual_pp_size): layer_offset = ( - virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model + virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + + pp_rank_idx * num_layers_per_model ) for layer_idx in range(num_layers_per_model): layer_map[layer_offset + layer_idx] = ( @@ -51,7 +52,12 @@ def _megatron_calc_layer_map(config): def load_state_dict_to_megatron_qwen2( - state_dict, wrapped_models, config, params_dtype, is_value_model=False, tie_word_embeddings=False + state_dict, + wrapped_models, + config, + params_dtype, + is_value_model=False, + tie_word_embeddings=False, ): """Load merged state_dict to sharded Megatron module in training.""" from megatron.core import DistributedDataParallel as LocalDDP @@ -70,7 +76,9 @@ def _get_gpt_model(model): def fetch_params(module): for param in module.parameters(): torch.distributed.fetch( - param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group() + param.data, + src=mpu.get_data_parallel_src_rank(), + group=mpu.get_data_parallel_group(), ) dp_rank = mpu.get_data_parallel_rank() @@ -89,7 +97,9 @@ def fetch_params(module): assert len(wrapped_models) == virtual_pp_size num_layers_per_model = config.num_hidden_layers // pp_size // 
virtual_pp_size - assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers, ( + assert ( + num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers + ), ( f"num_layers_per_model: {num_layers_per_model} * pp_size: {pp_size} * virtual_pp_size: " f"{virtual_pp_size} != config.num_hidden_layers: {config.num_hidden_layers}" ) @@ -107,7 +117,9 @@ def _fetch_tensor(tensor, name) -> torch.Tensor: if tensor is not None: tensor = tensor.data.copy_(state_dict[name], non_blocking=True) - def _fetch_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor: + def _fetch_tp_shard_tensor_vocab( + tensor, name, chunk_dim=0, mutate_func=None + ) -> torch.Tensor: """fetch tensor in tp shards""" nonlocal state_dict tp_rank = mpu.get_tensor_model_parallel_rank() @@ -123,7 +135,9 @@ def _fetch_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None) -> else: print(f"tp_shard tensor:[{name}] not in state_dict, skip loading") - def _fetch_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor: + def _fetch_tp_shard_tensor( + tensor, name, chunk_dim=0, mutate_func=None + ) -> torch.Tensor: """fetch tensor in tp shards""" nonlocal state_dict tp_rank = mpu.get_tensor_model_parallel_rank() @@ -149,23 +163,34 @@ def _fetch_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tensor: gate_weight = state_dict[gate_name] up_weight = state_dict[up_name] new_gate_up_weight = torch.empty( - config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=get_device_id() + config.intermediate_size * 2, + config.hidden_size, + dtype=params_dtype, + device=get_device_id(), ) for i in range(tp_size): intermediate_size_tp = config.intermediate_size // tp_size - gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] - up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] - new_gate_up_weight[intermediate_size_tp * 2 * i 
: intermediate_size_tp * 2 * (i + 1)].copy_( - torch.cat([gate_weight_tp, up_weight_tp], dim=0) - ) + gate_weight_tp = gate_weight[ + i * intermediate_size_tp : (i + 1) * intermediate_size_tp + ] + up_weight_tp = up_weight[ + i * intermediate_size_tp : (i + 1) * intermediate_size_tp + ] + new_gate_up_weight[ + intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1) + ].copy_(torch.cat([gate_weight_tp, up_weight_tp], dim=0)) tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0) if tensor is not None: tensor = tensor.data.copy_(tensor_chunk[tp_rank], non_blocking=True) else: - print(f"tp_shard tensor:[{gate_name}, {up_name}] not in state_dict, skip loading") + print( + f"tp_shard tensor:[{gate_name}, {up_name}] not in state_dict, skip loading" + ) - def _fetch_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) -> torch.Tensor: + def _fetch_tp_shard_tensor_qkv( + tensor, q_name, k_name, v_name, bias=False + ) -> torch.Tensor: """fetch tensor in tp shards across mp_group""" nonlocal state_dict nonlocal mp_group @@ -184,15 +209,22 @@ def _fetch_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) -> to total_size = q_size_tp + 2 * kv_size_tp if not bias: new_weight_qkv = torch.empty( - total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id() + total_size * tp_size, + config.hidden_size, + dtype=params_dtype, + device=get_device_id(), ) else: - new_weight_qkv = torch.empty(total_size * tp_size, dtype=params_dtype, device=get_device_id()) + new_weight_qkv = torch.empty( + total_size * tp_size, dtype=params_dtype, device=get_device_id() + ) for i in range(tp_size): q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] k_part = full_weight_k[i * kv_size_tp : (i + 1) * kv_size_tp] v_part = full_weight_v[i * kv_size_tp : (i + 1) * kv_size_tp] - new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0)) + new_weight_qkv[i * total_size : (i + 1) * 
total_size].copy_( + torch.cat([q_part, k_part, v_part], dim=0) + ) else: q_size_tp = config.hidden_size // tp_size @@ -200,17 +232,28 @@ def _fetch_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) -> to total_size = q_size_tp + 2 * kv_size_tp if not bias: new_weight_qkv = torch.empty( - total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id() + total_size * tp_size, + config.hidden_size, + dtype=params_dtype, + device=get_device_id(), ) else: - new_weight_qkv = torch.empty(total_size * tp_size, dtype=params_dtype, device=get_device_id()) + new_weight_qkv = torch.empty( + total_size * tp_size, dtype=params_dtype, device=get_device_id() + ) for i in range(tp_size): q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] - start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head - end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head + start_idx = ( + i * config.num_key_value_heads // tp_size * hidden_size_per_head + ) + end_idx = ( + i * config.num_key_value_heads // tp_size + 1 + ) * hidden_size_per_head k_part = full_weight_k[start_idx:end_idx] v_part = full_weight_v[start_idx:end_idx] - new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0)) + new_weight_qkv[i * total_size : (i + 1) * total_size].copy_( + torch.cat([q_part, k_part, v_part], dim=0) + ) tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0) if tensor is not None: @@ -238,9 +281,10 @@ def _fetch_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) -> to for vpp_rank in range(vpp_size): num_layer_vpp_chunk = num_layer_per_pp // vpp_size num_layer_this_model = num_layer_vpp_chunk - offset = vpp_rank * (config.num_hidden_layers // mpu.get_virtual_pipeline_model_parallel_world_size()) + ( - mpu.get_pipeline_model_parallel_rank() * num_layer_vpp_chunk - ) + offset = vpp_rank * ( + config.num_hidden_layers + // 
mpu.get_virtual_pipeline_model_parallel_world_size() + ) + (mpu.get_pipeline_model_parallel_rank() * num_layer_vpp_chunk) layer_list.extend(list(range(offset, offset + num_layer_this_model))) else: num_layer_this_model = num_layer_per_pp @@ -287,7 +331,11 @@ def _fetch_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) -> to ) _fetch_tensor( - sync_layer.post_attention_layernorm.weight if dst_pp_rank == pp_rank else None, + ( + sync_layer.post_attention_layernorm.weight + if dst_pp_rank == pp_rank + else None + ), f"{layer_name}.post_attention_layernorm.weight", ) @@ -319,10 +367,16 @@ def _fetch_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) -> to lm_head_weight = gpt_model_module.lm_head.weight if is_value_model: - if "lm_head.weight" in state_dict and state_dict["lm_head.weight"].shape[0] == 1: + if ( + "lm_head.weight" in state_dict + and state_dict["lm_head.weight"].shape[0] == 1 + ): _fetch_tensor(lm_head_weight, "lm_head.weight") print_rank_0("load lm_head from value_head weight") - elif "reward_head.weight" in state_dict and state_dict["reward_head.weight"].shape[0] == 1: + elif ( + "reward_head.weight" in state_dict + and state_dict["reward_head.weight"].shape[0] == 1 + ): _fetch_tensor(lm_head_weight, "reward_head.weight") print_rank_0("load lm_head from value_head weight") else: @@ -334,4 +388,6 @@ def _fetch_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) -> to dist.barrier() get_torch_device().empty_cache() - print_rank_0(f"loading megatron ckpt done, time elapsed {time.time() - start_time}s") + print_rank_0( + f"loading megatron ckpt done, time elapsed {time.time() - start_time}s" + ) diff --git a/Agent0/executor_train/verl/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader_depracated.py b/Agent0/executor_train/verl/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader_depracated.py index 770e365..fd5fe55 100644 --- 
a/Agent0/executor_train/verl/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader_depracated.py +++ b/Agent0/executor_train/verl/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader_depracated.py @@ -39,7 +39,8 @@ def _megatron_calc_layer_map(config): for pp_rank_idx in range(pp_size): for virtual_pp_rank_idx in range(virtual_pp_size): layer_offset = ( - virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model + virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + + pp_rank_idx * num_layers_per_model ) for layer_idx in range(num_layers_per_model): layer_map[layer_offset + layer_idx] = ( @@ -51,7 +52,12 @@ def _megatron_calc_layer_map(config): def load_state_dict_to_megatron_qwen2( - state_dict, wrapped_models, config, params_dtype, is_value_model=False, tie_word_embeddings=False + state_dict, + wrapped_models, + config, + params_dtype, + is_value_model=False, + tie_word_embeddings=False, ): """Load merged state_dict to sharded Megatron module in training.""" from megatron.core import DistributedDataParallel as LocalDDP @@ -70,7 +76,9 @@ def _get_gpt_model(model): def broadcast_params(module): for param in module.parameters(): torch.distributed.broadcast( - param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group() + param.data, + src=mpu.get_data_parallel_src_rank(), + group=mpu.get_data_parallel_group(), ) dp_rank = mpu.get_data_parallel_rank() @@ -89,7 +97,9 @@ def broadcast_params(module): assert len(wrapped_models) == virtual_pp_size num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size - assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers, ( + assert ( + num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers + ), ( f"num_layers_per_model: {num_layers_per_model} * pp_size: {pp_size} * virtual_pp_size: " f"{virtual_pp_size} != config.num_hidden_layers: {config.num_hidden_layers}" ) @@ 
-135,7 +145,9 @@ def _broadcast_tensor(tensor, name) -> torch.Tensor: tensor.data.copy_(weight) dist.broadcast(tensor, src=0, group=mp_group) - def _broadcast_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor: + def _broadcast_tp_shard_tensor_vocab( + tensor, name, chunk_dim=0, mutate_func=None + ) -> torch.Tensor: """broadcast tensor in tp shards across mp_group""" nonlocal state_dict nonlocal mp_group @@ -171,10 +183,12 @@ def _broadcast_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None requires_grad=False, ) else: - assert tensor.shape == chunk_shape, ( - f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" + assert ( + tensor.shape == chunk_shape + ), f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" + sync_tensor = torch.empty_like( + tensor, device=get_device_id(), requires_grad=False ) - sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) for i in range(tp_size): if torch.distributed.get_rank() == 0: @@ -183,7 +197,9 @@ def _broadcast_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None if (i == tp_rank) and (tensor is not None): tensor.data.copy_(sync_tensor) - def _broadcast_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor: + def _broadcast_tp_shard_tensor( + tensor, name, chunk_dim=0, mutate_func=None + ) -> torch.Tensor: """broadcast tensor in tp shards across mp_group""" nonlocal state_dict nonlocal mp_group @@ -218,10 +234,12 @@ def _broadcast_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> t requires_grad=False, ) else: - assert tensor.shape == chunk_shape, ( - f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" + assert ( + tensor.shape == chunk_shape + ), f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" + sync_tensor = torch.empty_like( + tensor, 
device=get_device_id(), requires_grad=False ) - sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) for i in range(tp_size): if torch.distributed.get_rank() == 0: @@ -241,15 +259,22 @@ def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tens gate_weight = state_dict[gate_name] up_weight = state_dict[up_name] new_gate_up_weight = torch.empty( - config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=get_device_id() + config.intermediate_size * 2, + config.hidden_size, + dtype=params_dtype, + device=get_device_id(), ) for i in range(tp_size): intermediate_size_tp = config.intermediate_size // tp_size - gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] - up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] - new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)].copy_( - torch.cat([gate_weight_tp, up_weight_tp], dim=0) - ) + gate_weight_tp = gate_weight[ + i * intermediate_size_tp : (i + 1) * intermediate_size_tp + ] + up_weight_tp = up_weight[ + i * intermediate_size_tp : (i + 1) * intermediate_size_tp + ] + new_gate_up_weight[ + intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1) + ].copy_(torch.cat([gate_weight_tp, up_weight_tp], dim=0)) tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0) chunk_shape = tensor_chunk[0].shape @@ -261,7 +286,9 @@ def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tens chunk_shape = obj_list[0] if chunk_shape is None: # all or none ranks in the mp_group should reach here - print_rank_0(f"tp_shard tensor:[{gate_name, up_name}] not in state_dict, skip loading") + print_rank_0( + f"tp_shard tensor:[{gate_name, up_name}] not in state_dict, skip loading" + ) return if tensor is None: @@ -276,7 +303,9 @@ def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tens f"rank 
#{torch.distributed.get_rank() == 0:} tensor {gate_name, up_name} shape " f"{tensor.shape} != {chunk_shape}" ) - sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) + sync_tensor = torch.empty_like( + tensor, device=get_device_id(), requires_grad=False + ) for i in range(tp_size): if torch.distributed.get_rank() == 0: @@ -285,7 +314,9 @@ def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tens if (i == tp_rank) and (tensor is not None): tensor.data.copy_(sync_tensor) - def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) -> torch.Tensor: + def _broadcast_tp_shard_tensor_qkv( + tensor, q_name, k_name, v_name, bias=False + ) -> torch.Tensor: """broadcast tensor in tp shards across mp_group""" nonlocal state_dict nonlocal mp_group @@ -293,7 +324,9 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) - tp_size = mpu.get_tensor_model_parallel_world_size() if torch.distributed.get_rank() == 0: - assert q_name in state_dict and k_name in state_dict and v_name in state_dict + assert ( + q_name in state_dict and k_name in state_dict and v_name in state_dict + ) full_weight_q = state_dict[q_name] full_weight_k = state_dict[k_name] full_weight_v = state_dict[v_name] @@ -302,14 +335,21 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) - if config.num_key_value_heads >= tp_size: q_size_tp = config.hidden_size // tp_size - kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size + kv_size_tp = ( + hidden_size_per_head * config.num_key_value_heads // tp_size + ) total_size = q_size_tp + 2 * kv_size_tp if not bias: new_weight_qkv = torch.empty( - total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id() + total_size * tp_size, + config.hidden_size, + dtype=params_dtype, + device=get_device_id(), ) else: - new_weight_qkv = torch.empty(total_size * tp_size, dtype=params_dtype, 
device=get_device_id()) + new_weight_qkv = torch.empty( + total_size * tp_size, dtype=params_dtype, device=get_device_id() + ) for i in range(tp_size): q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] k_part = full_weight_k[i * kv_size_tp : (i + 1) * kv_size_tp] @@ -324,14 +364,23 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) - total_size = q_size_tp + 2 * kv_size_tp if not bias: new_weight_qkv = torch.empty( - total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id() + total_size * tp_size, + config.hidden_size, + dtype=params_dtype, + device=get_device_id(), ) else: - new_weight_qkv = torch.empty(total_size * tp_size, dtype=params_dtype, device=get_device_id()) + new_weight_qkv = torch.empty( + total_size * tp_size, dtype=params_dtype, device=get_device_id() + ) for i in range(tp_size): q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] - start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head - end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head + start_idx = ( + i * config.num_key_value_heads // tp_size * hidden_size_per_head + ) + end_idx = ( + i * config.num_key_value_heads // tp_size + 1 + ) * hidden_size_per_head k_part = full_weight_k[start_idx:end_idx] v_part = full_weight_v[start_idx:end_idx] new_weight_qkv[i * total_size : (i + 1) * total_size].copy_( @@ -348,7 +397,9 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) - chunk_shape = obj_list[0] if chunk_shape is None: # all or none ranks in the mp_group should reach here - print_rank_0(f"tp_shard tensor:[{q_name, k_name, v_name}] not in state_dict, skip loading") + print_rank_0( + f"tp_shard tensor:[{q_name, k_name, v_name}] not in state_dict, skip loading" + ) return if tensor is None: @@ -359,10 +410,12 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) - requires_grad=False, ) else: - assert tensor.shape == 
chunk_shape, ( - f"rank #{torch.distributed.get_rank()} tensor {q_name} shape {tensor.shape} != {chunk_shape}" + assert ( + tensor.shape == chunk_shape + ), f"rank #{torch.distributed.get_rank()} tensor {q_name} shape {tensor.shape} != {chunk_shape}" + sync_tensor = torch.empty_like( + tensor, device=get_device_id(), requires_grad=False ) - sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) for i in range(tp_size): if torch.distributed.get_rank() == 0: @@ -379,7 +432,9 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) - embed_tokens_weight = None if pp_rank == 0: embed_tokens_weight = gpt_model_module.model.embed_tokens.weight - _broadcast_tp_shard_tensor_vocab(embed_tokens_weight, "model.embed_tokens.weight") + _broadcast_tp_shard_tensor_vocab( + embed_tokens_weight, "model.embed_tokens.weight" + ) # Transformer layers # ------------------- @@ -399,7 +454,11 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) - ) _broadcast_tp_shard_tensor_qkv( - sync_layer.self_attn.qkv_proj.weight if dst_pp_rank == pp_rank else None, + ( + sync_layer.self_attn.qkv_proj.weight + if dst_pp_rank == pp_rank + else None + ), f"{layer_name}.self_attn.q_proj.weight", f"{layer_name}.self_attn.k_proj.weight", f"{layer_name}.self_attn.v_proj.weight", @@ -420,7 +479,11 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) - ) _broadcast_tensor( - sync_layer.post_attention_layernorm.weight if dst_pp_rank == pp_rank else None, + ( + sync_layer.post_attention_layernorm.weight + if dst_pp_rank == pp_rank + else None + ), f"{layer_name}.post_attention_layernorm.weight", ) @@ -453,10 +516,16 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) - lm_head_weight = gpt_model_module.lm_head.weight if is_value_model: - if "lm_head.weight" in state_dict and state_dict["lm_head.weight"].shape[0] == 1: + if ( + "lm_head.weight" in state_dict + and 
state_dict["lm_head.weight"].shape[0] == 1 + ): _broadcast_tensor(lm_head_weight, "lm_head.weight") print_rank_0("load lm_head from value_head weight") - elif "reward_head.weight" in state_dict and state_dict["reward_head.weight"].shape[0] == 1: + elif ( + "reward_head.weight" in state_dict + and state_dict["reward_head.weight"].shape[0] == 1 + ): _broadcast_tensor(lm_head_weight, "reward_head.weight") print_rank_0("load lm_head from value_head weight") else: @@ -472,4 +541,6 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) - broadcast_params(wrapped_model) get_torch_device().empty_cache() - print_rank_0(f"loading megatron ckpt done, time elapsed {time.time() - start_time}s") + print_rank_0( + f"loading megatron ckpt done, time elapsed {time.time() - start_time}s" + ) diff --git a/Agent0/executor_train/verl/verl/models/qwen2/megatron/checkpoint_utils/qwen2_saver.py b/Agent0/executor_train/verl/verl/models/qwen2/megatron/checkpoint_utils/qwen2_saver.py index 737f73b..23facd1 100644 --- a/Agent0/executor_train/verl/verl/models/qwen2/megatron/checkpoint_utils/qwen2_saver.py +++ b/Agent0/executor_train/verl/verl/models/qwen2/megatron/checkpoint_utils/qwen2_saver.py @@ -32,9 +32,9 @@ def _megatron_calc_global_rank(tp_rank: int = 0, dp_rank: int = 0, pp_rank: int tp_size = mpu.get_tensor_model_parallel_world_size() dp_size = mpu.get_data_parallel_world_size() pp_size = mpu.get_pipeline_model_parallel_world_size() - assert tp_size * dp_size * pp_size == torch.distributed.get_world_size(), ( - f"{tp_size} x {dp_size} x {pp_size} != {torch.distributed.get_world_size()}" - ) + assert ( + tp_size * dp_size * pp_size == torch.distributed.get_world_size() + ), f"{tp_size} x {dp_size} x {pp_size} != {torch.distributed.get_world_size()}" # We only support TP-DP-PP grouping, for correctness when resharding return (pp_rank * dp_size + dp_rank) * tp_size + tp_rank @@ -58,7 +58,8 @@ def _megatron_calc_layer_map(config): for pp_rank_idx in 
range(pp_size): for virtual_pp_rank_idx in range(virtual_pp_size): layer_offset = ( - virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model + virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + + pp_rank_idx * num_layers_per_model ) for layer_idx in range(num_layers_per_model): layer_map[layer_offset + layer_idx] = ( @@ -69,7 +70,9 @@ def _megatron_calc_layer_map(config): return layer_map -def merge_megatron_ckpt_qwen2(wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False): +def merge_megatron_ckpt_qwen2( + wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False +): """Merge sharded parameters of a Megatron module into a merged checkpoint. Args: @@ -111,10 +114,10 @@ def _get_gpt_model(model): for i, wrapped_model in enumerate(wrapped_models): models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module)) - assert len(models[i].model.layers) == num_layers_per_model, ( - "len model layers {} not equal to num_layers_per_model {}".format( - len(models[i].model.layers), num_layers_per_model - ) + assert ( + len(models[i].model.layers) == num_layers_per_model + ), "len model layers {} not equal to num_layers_per_model {}".format( + len(models[i].model.layers), num_layers_per_model ) state_dict = dict() @@ -165,7 +168,9 @@ def _broadcast_tensor(tensor, name, src_pp_rank) -> torch.Tensor: if torch.distributed.get_rank() == 0: state_dict[name] = _get_cpu_tensor(weight) - def _broadcast_tp_shard_tensor(tensor, name, src_pp_rank, concat_dim=0, mutate_func=None) -> torch.Tensor: + def _broadcast_tp_shard_tensor( + tensor, name, src_pp_rank, concat_dim=0, mutate_func=None + ) -> torch.Tensor: """broadcast tensor in tp shards across mp_group""" nonlocal state_dict nonlocal mp_group @@ -192,8 +197,14 @@ def _broadcast_tp_shard_tensor(tensor, name, src_pp_rank, concat_dim=0, mutate_f chunk_tensors = [None] * tp_size for i in range(tp_size): - 
cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank) - sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor + cur_src_rank = _megatron_calc_global_rank( + tp_rank=i, dp_rank=0, pp_rank=src_pp_rank + ) + sync_tensor = ( + tensor + if torch.distributed.get_rank() == cur_src_rank + else buffer_tensor + ) dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group) if torch.distributed.get_rank() == 0: @@ -205,7 +216,9 @@ def _broadcast_tp_shard_tensor(tensor, name, src_pp_rank, concat_dim=0, mutate_f full_tensor = mutate_func(full_tensor) state_dict[name] = full_tensor - def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name, src_pp_rank) -> torch.Tensor: + def _broadcast_tp_shard_tensor_gate_up( + tensor, gate_name, up_name, src_pp_rank + ) -> torch.Tensor: """broadcast tensor in tp shards across mp_group""" nonlocal state_dict nonlocal mp_group @@ -219,7 +232,9 @@ def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name, src_pp_rank) chunk_shape = obj_list[0] if chunk_shape is None: # all or none ranks in the mp_group should reach here - print_rank_0(f"tp_shard tensor:[{gate_name, up_name}] not exist, skip collecting") + print_rank_0( + f"tp_shard tensor:[{gate_name, up_name}] not exist, skip collecting" + ) return buffer_tensor = torch.empty( @@ -232,8 +247,14 @@ def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name, src_pp_rank) chunk_tensors = [None] * tp_size for i in range(tp_size): - cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank) - sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor + cur_src_rank = _megatron_calc_global_rank( + tp_rank=i, dp_rank=0, pp_rank=src_pp_rank + ) + sync_tensor = ( + tensor + if torch.distributed.get_rank() == cur_src_rank + else buffer_tensor + ) dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group) if torch.distributed.get_rank() == 0: @@ -245,7 +266,9 
@@ def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name, src_pp_rank) gate_weight_list = [] up_weight_list = [] for i in range(tp_size): - gate_up_weight_tp = full_tensor[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)] + gate_up_weight_tp = full_tensor[ + intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1) + ] gate_weight_tp = gate_up_weight_tp[:intermediate_size_tp] up_weight_tp = gate_up_weight_tp[intermediate_size_tp:] gate_weight_list.append(gate_weight_tp) @@ -281,8 +304,14 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, src_pp_rank): chunk_tensors = [None] * tp_size for i in range(tp_size): - cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank) - sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor + cur_src_rank = _megatron_calc_global_rank( + tp_rank=i, dp_rank=0, pp_rank=src_pp_rank + ) + sync_tensor = ( + tensor + if torch.distributed.get_rank() == cur_src_rank + else buffer_tensor + ) dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group) if torch.distributed.get_rank() == 0: @@ -297,7 +326,9 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, src_pp_rank): if config.num_key_value_heads >= tp_size: q_size_tp = config.hidden_size // tp_size - kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size + kv_size_tp = ( + hidden_size_per_head * config.num_key_value_heads // tp_size + ) total_size = q_size_tp + 2 * kv_size_tp for i in range(tp_size): qkv_part = full_tensor[i * total_size : (i + 1) * total_size] @@ -422,16 +453,23 @@ def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, src_pp_rank): src_pp_rank=pp_size - 1, ) _broadcast_tensor( - gpt_model_module.reward_head.weight - if pp_rank == pp_size - 1 and getattr(gpt_model_module, "reward_weight", None) is not None - else None, + ( + gpt_model_module.reward_head.weight + if pp_rank == pp_size - 1 + and 
getattr(gpt_model_module, "reward_weight", None) is not None + else None + ), "reward_head.weight", src_pp_rank=pp_size - 1, ) else: _broadcast_tp_shard_tensor( - getattr(gpt_model_module.lm_head, "weight", None) if pp_rank == pp_size - 1 else None, + ( + getattr(gpt_model_module.lm_head, "weight", None) + if pp_rank == pp_size - 1 + else None + ), "lm_head.weight", src_pp_rank=pp_size - 1, ) diff --git a/Agent0/executor_train/verl/verl/models/qwen2/megatron/layers/parallel_attention.py b/Agent0/executor_train/verl/verl/models/qwen2/megatron/layers/parallel_attention.py index 702c429..32b2d22 100644 --- a/Agent0/executor_train/verl/verl/models/qwen2/megatron/layers/parallel_attention.py +++ b/Agent0/executor_train/verl/verl/models/qwen2/megatron/layers/parallel_attention.py @@ -46,17 +46,23 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + inv_freq = 1.0 / ( + self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) + ) self.register_buffer("inv_freq", inv_freq, persistent=False) # Build here to make `torch.jit.trace` work. 
self._set_cos_sin_cache( - seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + seq_len=max_position_embeddings, + device=self.inv_freq.device, + dtype=torch.get_default_dtype(), ) def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype + ) freqs = torch.einsum("i,j->ij", t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation @@ -78,13 +84,22 @@ def forward(self, x, seq_len=None): class Qwen2LinearScalingRotaryEmbedding(Qwen2RotaryEmbedding): """Qwen2RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + ): self.scaling_factor = scaling_factor super().__init__(dim, max_position_embeddings, base, device) def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype + ) t = t / self.scaling_factor freqs = torch.einsum("i,j->ij", t, self.inv_freq) @@ -97,7 +112,14 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): class Qwen2DynamicNTKScalingRotaryEmbedding(Qwen2RotaryEmbedding): """Qwen2RotaryEmbedding extended with Dynamic NTK scaling. 
Credits to the Reddit users /u/bloc97 and /u/emozilla""" - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + ): self.scaling_factor = scaling_factor super().__init__(dim, max_position_embeddings, base, device) @@ -106,12 +128,17 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): if seq_len > self.max_position_embeddings: base = self.base * ( - (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) + (self.scaling_factor * seq_len / self.max_position_embeddings) + - (self.scaling_factor - 1) ) ** (self.dim / (self.dim - 2)) - inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + inv_freq = 1.0 / ( + base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) + ) self.register_buffer("inv_freq", inv_freq, persistent=False) - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype + ) freqs = torch.einsum("i,j->ij", t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation @@ -143,7 +170,9 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) @@ -164,9 +193,9 @@ def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig): # assign values after tp tp_size = mpu.get_tensor_model_parallel_world_size() - assert 
self.num_heads % tp_size == 0, ( - f"num_head must be divisible by tp_size. Got num_head={self.num_heads}, tp_size={tp_size}" - ) + assert ( + self.num_heads % tp_size == 0 + ), f"num_head must be divisible by tp_size. Got num_head={self.num_heads}, tp_size={tp_size}" assert self.num_key_value_heads % tp_size == 0, ( f"num_key_value_heads must be divisible by tp_size. Got num_key_value_heads=" f"{self.num_key_value_heads}, tp_size={tp_size}" @@ -228,7 +257,11 @@ def _init_rope(self): ) def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + return ( + tensor.view(bsz, seq_len, self.num_heads, self.head_dim) + .transpose(1, 2) + .contiguous() + ) def forward( self, @@ -238,20 +271,32 @@ def forward( ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() qkv = self.qkv_proj(hidden_states)[0] - query_states, key_states, value_states = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1) + query_states, key_states, value_states = qkv.split( + [self.q_size, self.k_size, self.v_size], dim=-1 + ) - query_states = query_states.view(bsz, q_len, self.num_heads_per_tp, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads_per_tp, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads_per_tp, self.head_dim).transpose(1, 2) + query_states = query_states.view( + bsz, q_len, self.num_heads_per_tp, self.head_dim + ).transpose(1, 2) + key_states = key_states.view( + bsz, q_len, self.num_key_value_heads_per_tp, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + bsz, q_len, self.num_key_value_heads_per_tp, self.head_dim + ).transpose(1, 2) kv_seq_len = key_states.shape[-2] cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, 
key_states, cos, sin, position_ids) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids + ) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + attn_weights = torch.matmul( + query_states, key_states.transpose(2, 3) + ) / math.sqrt(self.head_dim) if attn_weights.size() != (bsz, self.num_heads_per_tp, q_len, kv_seq_len): raise ValueError( @@ -267,7 +312,9 @@ def forward( attn_weights = attn_weights + attention_mask # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(query_states.dtype) attn_output = torch.matmul(attn_weights, value_states) if attn_output.size() != (bsz, self.num_heads_per_tp, q_len, self.head_dim): @@ -292,7 +339,9 @@ def forward( def apply_rotary_pos_emb_rmpad(q, k, cos, sin, position_ids, indices, sequence_length): batch_size = position_ids.shape[0] - q = pad_input(q, indices, batch_size, sequence_length) # (batch_size, seqlen, num_head, head_dim) + q = pad_input( + q, indices, batch_size, sequence_length + ) # (batch_size, seqlen, num_head, head_dim) k = pad_input(k, indices, batch_size, sequence_length) cos = cos[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim] sin = sin[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim] @@ -309,10 +358,22 @@ def apply_rotary_pos_emb_rmpad(q, k, cos, sin, position_ids, indices, sequence_l # cos/sin shoudl be: (seq_length, rotary_dim / 2) def apply_rotary_pos_emb_rmpad_flash(q, k, cos, sin, cu_seqlens, max_seqlen): q_embed = apply_rotary_emb( - q, cos, sin, interleaved=False, inplace=False, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen + q, + cos, + sin, + interleaved=False, + inplace=False, + cu_seqlens=cu_seqlens, + 
max_seqlen=max_seqlen, ) k_embed = apply_rotary_emb( - k, cos, sin, interleaved=False, inplace=False, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen + k, + cos, + sin, + interleaved=False, + inplace=False, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, ) return q_embed, k_embed @@ -327,7 +388,9 @@ def forward( cu_seqlens: torch.Tensor = None, max_seqlen_in_batch: int = None, ): - total_nnz, _, _ = hidden_states.size() # This is the total_nnz padded after sequence parallel + total_nnz, _, _ = ( + hidden_states.size() + ) # This is the total_nnz padded after sequence parallel if self.megatron_config.sequence_parallel: total_nnz = total_nnz * mpu.get_tensor_model_parallel_world_size() @@ -347,14 +410,28 @@ def forward( # Flash attention requires the input to have the shape # batch_size x seq_length x head_dime x hidden_dim # therefore we just need to keep the original shape - query_states = query_states.view(total_nnz, self.num_heads_per_tp, self.head_dim) - key_states = key_states.view(total_nnz, self.num_key_value_heads_per_tp, self.head_dim) - value_states = value_states.view(total_nnz, self.num_key_value_heads_per_tp, self.head_dim) + query_states = query_states.view( + total_nnz, self.num_heads_per_tp, self.head_dim + ) + key_states = key_states.view( + total_nnz, self.num_key_value_heads_per_tp, self.head_dim + ) + value_states = value_states.view( + total_nnz, self.num_key_value_heads_per_tp, self.head_dim + ) cos, sin = self.rotary_emb(value_states, seq_len=sequence_length) - cos, sin = cos[:, : cos.shape[1] // 2], sin[:, : sin.shape[1] // 2] # flash attn only needs half + cos, sin = ( + cos[:, : cos.shape[1] // 2], + sin[:, : sin.shape[1] // 2], + ) # flash attn only needs half query_states, key_states = apply_rotary_pos_emb_rmpad_flash( - query_states, key_states, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen_in_batch + query_states, + key_states, + cos, + sin, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen_in_batch, ) # query_states, key_states 
= apply_rotary_pos_emb_rmpad(query_states, key_states, cos, sin, # position_ids, indices, @@ -388,12 +465,16 @@ def forward( ) attn_output_unpad = attn_output_unpad.to(input_dtype) - attn_output_unpad = attn_output_unpad.reshape(total_nnz, 1, self.hidden_size_per_tp).contiguous() + attn_output_unpad = attn_output_unpad.reshape( + total_nnz, 1, self.hidden_size_per_tp + ).contiguous() # sequence parallel reduce_scatter is performed inside RowColumnParallel if enabled # Here we need to repad if self.megatron_config.sequence_parallel: - attn_output_unpad = F.pad(attn_output_unpad, pad=(0, 0, 0, 0, 0, sequence_parallel_pad)) + attn_output_unpad = F.pad( + attn_output_unpad, pad=(0, 0, 0, 0, 0, sequence_parallel_pad) + ) attn_output_unpad = self.o_proj(attn_output_unpad)[0] return attn_output_unpad diff --git a/Agent0/executor_train/verl/verl/models/qwen2/megatron/layers/parallel_decoder.py b/Agent0/executor_train/verl/verl/models/qwen2/megatron/layers/parallel_decoder.py index 3c8a2a6..44705db 100644 --- a/Agent0/executor_train/verl/verl/models/qwen2/megatron/layers/parallel_decoder.py +++ b/Agent0/executor_train/verl/verl/models/qwen2/megatron/layers/parallel_decoder.py @@ -33,12 +33,16 @@ class ParallelQwen2DecoderLayer(nn.Module): - def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig, layer_idx: int): + def __init__( + self, config: Qwen2Config, megatron_config: ModelParallelConfig, layer_idx: int + ): super().__init__() self.config: TransformerConfig = convert_config(config, megatron_config) self.layer_idx = layer_idx self.hidden_size = config.hidden_size - self.self_attn = ParallelQwen2Attention(config=config, megatron_config=megatron_config) + self.self_attn = ParallelQwen2Attention( + config=config, megatron_config=megatron_config + ) self.mlp = ParallelQwen2MLP(config, megatron_config=megatron_config) self.input_layernorm = ParallelQwen2RMSNorm(config, megatron_config) @@ -49,7 +53,9 @@ def forward( hidden_states: torch.Tensor, 
attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + ) -> tuple[ + torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]] + ]: """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` @@ -100,12 +106,16 @@ def forward( class ParallelQwen2DecoderLayerRmPad(nn.Module): - def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig, layer_idx: int): + def __init__( + self, config: Qwen2Config, megatron_config: ModelParallelConfig, layer_idx: int + ): super().__init__() self.config: TransformerConfig = convert_config(config, megatron_config) self.hidden_size = config.hidden_size self.layer_idx = layer_idx - self.self_attn = ParallelQwen2AttentionRmPad(config=config, megatron_config=megatron_config) + self.self_attn = ParallelQwen2AttentionRmPad( + config=config, megatron_config=megatron_config + ) self.mlp = ParallelQwen2MLP(config, megatron_config=megatron_config) self.input_layernorm = ParallelQwen2RMSNorm(config, megatron_config) @@ -119,7 +129,9 @@ def forward( indices: torch.Tensor = None, cu_seqlens: int = None, max_seqlen_in_batch: int = None, - ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + ) -> tuple[ + torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]] + ]: residual = hidden_states # (total_nnz // sp, 1, hidden_size) hidden_states = self.input_layernorm(hidden_states) diff --git a/Agent0/executor_train/verl/verl/models/qwen2/megatron/modeling_qwen2_megatron.py b/Agent0/executor_train/verl/verl/models/qwen2/megatron/modeling_qwen2_megatron.py index 92e81be..64ce701 100644 --- a/Agent0/executor_train/verl/verl/models/qwen2/megatron/modeling_qwen2_megatron.py +++ b/Agent0/executor_train/verl/verl/models/qwen2/megatron/modeling_qwen2_megatron.py @@ -34,7 +34,11 @@ from verl.utils.megatron 
import tensor_parallel as tp_utils from verl.utils.megatron_utils import TransformerConfig, convert_config -from .layers import ParallelQwen2DecoderLayer, ParallelQwen2DecoderLayerRmPad, ParallelQwen2RMSNorm +from .layers import ( + ParallelQwen2DecoderLayer, + ParallelQwen2DecoderLayerRmPad, + ParallelQwen2RMSNorm, +) """ TODO: @@ -45,7 +49,9 @@ # Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device): +def _make_causal_mask( + input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device +): """ Make causal mask used for bi-directional self-attention. """ @@ -69,7 +75,9 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] inverted_mask = 1.0 - expanded_mask - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + return inverted_mask.masked_fill( + inverted_mask.to(torch.bool), torch.finfo(dtype).min + ) class ParallelQwen2Model(nn.Module): @@ -87,19 +95,28 @@ def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig): self.vocab_size = config.vocab_size embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding() if megatron_config is not None: - assert embedding_kwargs.get("config", False), "must have ModelParallelConfig" + assert embedding_kwargs.get( + "config", False + ), "must have ModelParallelConfig" tp_utils.update_kwargs_with_config(embedding_kwargs, megatron_config) self.embed_tokens = tensor_parallel.VocabParallelEmbedding( - num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs + num_embeddings=config.vocab_size, + embedding_dim=config.hidden_size, + **embedding_kwargs, ) self.layers = nn.ModuleList( - [ParallelQwen2DecoderLayer(config, megatron_config) for _ in range(config.num_hidden_layers)] + [ + ParallelQwen2DecoderLayer(config, megatron_config) + for _ in range(config.num_hidden_layers) + ] ) self.norm = 
ParallelQwen2RMSNorm(config, megatron_config) # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask - def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds): + def _prepare_decoder_attention_mask( + self, attention_mask, input_shape, inputs_embeds + ): # create causal mask # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] combined_attention_mask = None @@ -112,11 +129,13 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_em if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( - inputs_embeds.device - ) + expanded_attn_mask = _expand_mask( + attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ).to(inputs_embeds.device) combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + expanded_attn_mask + if combined_attention_mask is None + else expanded_attn_mask + combined_attention_mask ) return combined_attention_mask @@ -141,7 +160,9 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) # embed positions - attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), inputs_embeds) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds + ) hidden_states = inputs_embeds @@ -237,14 +258,21 @@ def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig): embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding() self.megatron_config = megatron_config if megatron_config is not None: - assert embedding_kwargs.get("config", False), "must have ModelParallelConfig" + assert embedding_kwargs.get( + "config", False + ), "must have ModelParallelConfig" tp_utils.update_kwargs_with_config(embedding_kwargs, 
self.megatron_config) self.embed_tokens = tensor_parallel.VocabParallelEmbedding( - num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs + num_embeddings=config.vocab_size, + embedding_dim=config.hidden_size, + **embedding_kwargs, ) self.layers = nn.ModuleList( - [ParallelQwen2DecoderLayerRmPad(config, megatron_config) for _ in range(config.num_hidden_layers)] + [ + ParallelQwen2DecoderLayerRmPad(config, megatron_config) + for _ in range(config.num_hidden_layers) + ] ) self.norm = ParallelQwen2RMSNorm(config, megatron_config) @@ -266,12 +294,16 @@ def forward( Returns: """ - inputs_embeds = self.embed_tokens(input_ids) # (1, total_nnz) -> (1, total_nnz, hidden_size) + inputs_embeds = self.embed_tokens( + input_ids + ) # (1, total_nnz) -> (1, total_nnz, hidden_size) # (1, total_nnz, hidden_size) -> (total_nnz, 1, hidden_size) -> (total_nnz // sp, 1, hidden_size) inputs_embeds = inputs_embeds.transpose(0, 1) if self.megatron_config.sequence_parallel: - inputs_embeds = tensor_parallel.scatter_to_sequence_parallel_region(inputs_embeds) + inputs_embeds = tensor_parallel.scatter_to_sequence_parallel_region( + inputs_embeds + ) hidden_states = inputs_embeds for idx, decoder_layer in enumerate(self.layers): @@ -318,7 +350,9 @@ def _forward_head(self, hidden_states): # all_gather from sequence parallel region is performed inside lm_head logits = self.lm_head(hidden_states)[0] logits = logits.float() # (total_nnz_padded, 1, vocab_size // tp) - logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits) # (total_nnz_padded, 1, vocab_size) + logits = tensor_parallel.gather_from_tensor_model_parallel_region( + logits + ) # (total_nnz_padded, 1, vocab_size) return logits def forward( @@ -389,7 +423,9 @@ def _init_head(self, config): if self.megatron_config is not None: assert column_kwargs.get("config", False), "must have ModelParallelConfig" tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) - self.lm_head = 
nn.Linear(in_features=config.hidden_size, out_features=1, bias=False) + self.lm_head = nn.Linear( + in_features=config.hidden_size, out_features=1, bias=False + ) # lm_head is effectively the same as sequence parallel sp_utils.mark_parameter_as_sequence_parallel(self.lm_head.weight) @@ -397,7 +433,9 @@ def _forward_head(self, hidden_states): logits = self.lm_head(hidden_states) # (total_nnz_padded // tp, 1, 1) logits = logits.float() if self.megatron_config.sequence_parallel: - logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False) + logits = tensor_parallel.gather_from_sequence_parallel_region( + logits, tensor_parallel_output_grad=False + ) return logits def forward( @@ -426,7 +464,13 @@ class ParallelQwen2ModelRmPadPP(nn.Module): config: Qwen2Config """ - def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig, pre_process, post_process): + def __init__( + self, + config: Qwen2Config, + megatron_config: ModelParallelConfig, + pre_process, + post_process, + ): super().__init__() self.config: TransformerConfig = convert_config(config, megatron_config) self.padding_idx = config.pad_token_id @@ -436,11 +480,15 @@ def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig, pr self.megatron_config = megatron_config embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding() if megatron_config is not None: - assert embedding_kwargs.get("config", False), "must have ModelParallelConfig" + assert embedding_kwargs.get( + "config", False + ), "must have ModelParallelConfig" tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config) if pre_process: self.embed_tokens = tensor_parallel.VocabParallelEmbedding( - num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs + num_embeddings=config.vocab_size, + embedding_dim=config.hidden_size, + **embedding_kwargs, ) else: self.embed_tokens = None @@ -454,14 +502,18 @@ def __init__(self, 
config: Qwen2Config, megatron_config: ModelParallelConfig, pr if vpp_size is not None: self.num_layer_vpp_chunk = self.num_layer_per_pp // vpp_size self.num_layer_this_model = self.num_layer_vpp_chunk - offset = vpp_rank * (config.num_hidden_layers // vpp_size) + (pp_rank * self.num_layer_vpp_chunk) + offset = vpp_rank * (config.num_hidden_layers // vpp_size) + ( + pp_rank * self.num_layer_vpp_chunk + ) else: self.num_layer_this_model = self.num_layer_per_pp offset = pp_rank * self.num_layer_per_pp self.layers = nn.ModuleList() for i in range(self.num_layer_this_model): - layer = ParallelQwen2DecoderLayerRmPad(config, megatron_config, layer_idx=i + offset) + layer = ParallelQwen2DecoderLayerRmPad( + config, megatron_config, layer_idx=i + offset + ) self.layers.add_module(f"{i}", layer) if post_process: @@ -498,14 +550,18 @@ def forward( """ if self.pre_process: - inputs_embeds = self.embed_tokens(input_ids) # (1, total_nnz) -> (1, total_nnz, hidden_size) + inputs_embeds = self.embed_tokens( + input_ids + ) # (1, total_nnz) -> (1, total_nnz, hidden_size) # vocab parallel embedding will not do sequence parallel reduce-scatter in open source megatron # so need to deal with it by handle here: # (1, total_nnz, hidden_size) -> (total_nnz, 1, hidden_size) -> (total_nnz // sp, 1, hidden_size) inputs_embeds = inputs_embeds.transpose(0, 1) if self.megatron_config.sequence_parallel: - inputs_embeds = tensor_parallel.scatter_to_sequence_parallel_region(inputs_embeds) + inputs_embeds = tensor_parallel.scatter_to_sequence_parallel_region( + inputs_embeds + ) hidden_states = inputs_embeds else: @@ -543,7 +599,10 @@ def __init__( self.config: TransformerConfig = convert_config(config, megatron_config) self.megatron_config = megatron_config self.model = ParallelQwen2ModelRmPadPP( - config, megatron_config=megatron_config, pre_process=pre_process, post_process=post_process + config, + megatron_config=megatron_config, + pre_process=pre_process, + post_process=post_process, ) 
self.share_embeddings_and_output_weights = share_embeddings_and_output_weights self.vocab_size = config.vocab_size @@ -576,7 +635,8 @@ def _init_head(self, config): bias=False, gather_output=False, skip_bias_add=False, - skip_weight_param_allocation=self.pre_process and self.share_embeddings_and_output_weights, + skip_weight_param_allocation=self.pre_process + and self.share_embeddings_and_output_weights, **column_kwargs, ) @@ -603,7 +663,11 @@ def setup_embeddings_and_output_layer(self) -> None: self.shared_embedding_or_output_weight().zero_out_wgrad = True return - if parallel_state.is_pipeline_first_stage() and self.pre_process and not self.post_process: + if ( + parallel_state.is_pipeline_first_stage() + and self.pre_process + and not self.post_process + ): self.shared_embedding_or_output_weight().shared_embedding = True if self.post_process and not self.pre_process: @@ -614,10 +678,15 @@ def setup_embeddings_and_output_layer(self) -> None: self.lm_head.weight.shared = True self.lm_head.weight.shared_embedding = True - if torch.distributed.is_initialized() and parallel_state.is_rank_in_embedding_group(): + if ( + torch.distributed.is_initialized() + and parallel_state.is_rank_in_embedding_group() + ): weight = self.shared_embedding_or_output_weight() weight.data = weight.data.to(get_device_name()) - torch.distributed.all_reduce(weight.data, group=parallel_state.get_embedding_group()) + torch.distributed.all_reduce( + weight.data, group=parallel_state.get_embedding_group() + ) def shared_embedding_or_output_weight(self) -> torch.Tensor: if self.pre_process: @@ -683,7 +752,9 @@ def forward( if self.post_process: hidden_states = outputs logits = self._forward_head(hidden_states) - logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension # torch.Size([8, 32, 16]) + logits = torch.squeeze( + logits, dim=1 + ) # remove the artificial batch dimension # torch.Size([8, 32, 16]) # remove padding from sequence parallel if 
self.megatron_config.sequence_parallel: @@ -711,7 +782,9 @@ def _init_head(self, config): if self.megatron_config is not None: assert column_kwargs.get("config", False), "must have ModelParallelConfig" tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) - self.lm_head = nn.Linear(in_features=config.hidden_size, out_features=1, bias=False) + self.lm_head = nn.Linear( + in_features=config.hidden_size, out_features=1, bias=False + ) # lm_head is effectively the same as sequence parallel sp_utils.mark_parameter_as_sequence_parallel(self.lm_head.weight) @@ -719,7 +792,9 @@ def _forward_head(self, hidden_states): logits = self.lm_head(hidden_states) # (total_nnz_padded // tp, 1, 1) logits = logits.float() if self.megatron_config.sequence_parallel: - logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False) + logits = tensor_parallel.gather_from_sequence_parallel_region( + logits, tensor_parallel_output_grad=False + ) return logits def forward( @@ -729,7 +804,11 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, ) -> tuple | CausalLMOutputWithPast: - output = super().forward(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids) + output = super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + ) if self.post_process: output.logits = torch.squeeze(output.logits, dim=-1) return output diff --git a/Agent0/executor_train/verl/verl/models/registry.py b/Agent0/executor_train/verl/verl/models/registry.py index 829b9e2..89b7e0d 100644 --- a/Agent0/executor_train/verl/verl/models/registry.py +++ b/Agent0/executor_train/verl/verl/models/registry.py @@ -22,15 +22,27 @@ _MODELS = { "LlamaForCausalLM": ( "llama", - ("ParallelLlamaForCausalLMRmPadPP", "ParallelLlamaForValueRmPadPP", "ParallelLlamaForCausalLMRmPad"), + ( + "ParallelLlamaForCausalLMRmPadPP", + 
"ParallelLlamaForValueRmPadPP", + "ParallelLlamaForCausalLMRmPad", + ), ), "Qwen2ForCausalLM": ( "qwen2", - ("ParallelQwen2ForCausalLMRmPadPP", "ParallelQwen2ForValueRmPadPP", "ParallelQwen2ForCausalLMRmPad"), + ( + "ParallelQwen2ForCausalLMRmPadPP", + "ParallelQwen2ForValueRmPadPP", + "ParallelQwen2ForCausalLMRmPad", + ), ), "MistralForCausalLM": ( "mistral", - ("ParallelMistralForCausalLMRmPadPP", "ParallelMistralForValueRmPadPP", "ParallelMistralForCausalLMRmPad"), + ( + "ParallelMistralForCausalLMRmPadPP", + "ParallelMistralForValueRmPadPP", + "ParallelMistralForCausalLMRmPad", + ), ), } @@ -50,7 +62,9 @@ def load_model_cls(model_arch: str, value=False) -> Optional[type[nn.Module]]: elif value: # critic/rm model_cls_name = model_cls_name[1] - module = importlib.import_module(f"verl.models.{module_name}.{megatron}.modeling_{module_name}_megatron") + module = importlib.import_module( + f"verl.models.{module_name}.{megatron}.modeling_{module_name}_megatron" + ) return getattr(module, model_cls_name, None) @staticmethod diff --git a/Agent0/executor_train/verl/verl/models/transformers/dense_common.py b/Agent0/executor_train/verl/verl/models/transformers/dense_common.py index 56fe293..73855c9 100644 --- a/Agent0/executor_train/verl/verl/models/transformers/dense_common.py +++ b/Agent0/executor_train/verl/verl/models/transformers/dense_common.py @@ -46,9 +46,15 @@ def forward_base_model( This function should be generic enough for all pure text models. 
```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states ) # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) @@ -111,7 +117,9 @@ def forward_with_torch_backend( elif input_ids is not None: rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1) else: - raise RuntimeError("To use forward_with_torch_backend, either labels or input_ids must be provided.") + raise RuntimeError( + "To use forward_with_torch_backend, either labels or input_ids must be provided." + ) fused_linear_for_ppo = FusedLinearForPPO() log_probs, entropy = fused_linear_for_ppo.forward( @@ -174,7 +182,9 @@ def forward_with_triton_backend( elif input_ids is not None: rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1) else: - raise RuntimeError("To use forward_with_triton_backend, either labels or input_ids must be provided.") + raise RuntimeError( + "To use forward_with_triton_backend, either labels or input_ids must be provided." 
+ ) log_probs, entropy = linear_cross_entropy( hidden_states, diff --git a/Agent0/executor_train/verl/verl/models/transformers/kimi_vl.py b/Agent0/executor_train/verl/verl/models/transformers/kimi_vl.py index edd7936..32f1796 100644 --- a/Agent0/executor_train/verl/verl/models/transformers/kimi_vl.py +++ b/Agent0/executor_train/verl/verl/models/transformers/kimi_vl.py @@ -80,7 +80,9 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) @@ -106,7 +108,9 @@ def _ulysses_flash_attn_forward( # batch_size x seq_length x head_dim x hidden_dim # therefore we just need to keep the original shape compressed_kv = self.kv_a_proj_with_mqa(hidden_states) - compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + compressed_kv, k_pe = torch.split( + compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) kv = ( self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) @@ -114,14 +118,18 @@ def _ulysses_flash_attn_forward( .transpose(1, 2) ) - k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + k_nope, value_states = torch.split( + kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1 + ) # patch ulysses_sp_size = get_ulysses_sequence_parallel_world_size() if ulysses_sp_size > 1: validate_ulysses_config(self.num_heads, ulysses_sp_size) - num_key_value_groups = self.config.num_attention_heads // self.config.num_key_value_heads + num_key_value_groups = ( + self.config.num_attention_heads // 
self.config.num_key_value_heads + ) k_pe = repeat_kv(k_pe, ulysses_sp_size) # to keep heads=1 after a2a k_nope = repeat_kv(k_nope, num_key_value_groups) value_states = repeat_kv(value_states, num_key_value_groups) @@ -135,15 +143,21 @@ def _ulysses_flash_attn_forward( else: full_q_len = q_len - q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + q_nope, q_pe = torch.split( + q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) cos, sin = self.rotary_emb(value_states, seq_len=full_q_len) q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) - query_states = k_pe.new_empty(bsz, self.num_heads // ulysses_sp_size, full_q_len, self.q_head_dim) + query_states = k_pe.new_empty( + bsz, self.num_heads // ulysses_sp_size, full_q_len, self.q_head_dim + ) query_states[:, :, :, : self.qk_nope_head_dim] = q_nope query_states[:, :, :, self.qk_nope_head_dim :] = q_pe - key_states = k_pe.new_empty(bsz, self.num_heads // ulysses_sp_size, full_q_len, self.q_head_dim) + key_states = k_pe.new_empty( + bsz, self.num_heads // ulysses_sp_size, full_q_len, self.q_head_dim + ) key_states[:, :, :, : self.qk_nope_head_dim] = k_nope key_states[:, :, :, self.qk_nope_head_dim :] = k_pe @@ -179,7 +193,9 @@ def _ulysses_flash_attn_forward( if self.q_head_dim != self.v_head_dim: attn_output = attn_output[:, :, :, : self.v_head_dim] - attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim).contiguous() + attn_output = attn_output.reshape( + bsz, q_len, self.num_heads * self.v_head_dim + ).contiguous() attn_output = self.o_proj(attn_output) return attn_output, None, None diff --git a/Agent0/executor_train/verl/verl/models/transformers/llama.py b/Agent0/executor_train/verl/verl/models/transformers/llama.py index 687ceab..56b279a 100644 --- a/Agent0/executor_train/verl/verl/models/transformers/llama.py +++ b/Agent0/executor_train/verl/verl/models/transformers/llama.py @@ -46,7 +46,9 @@ def llama_flash_attn_forward( 
output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[ + tuple[torch.Tensor, torch.Tensor] + ] = None, # will become mandatory in v4.46 **kwargs, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: """ @@ -65,9 +67,15 @@ def llama_flash_attn_forward( # Flash attention requires the input to have the shape # batch_size x seq_length x head_dim x hidden_dim # therefore we just need to keep the original shape - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + query_states = query_states.view( + bsz, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) + key_states = key_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) # trade off: repeat first and then all to all # key_states = repeat_kv(key_states, self.num_key_value_groups) @@ -101,7 +109,9 @@ def llama_flash_attn_forward( if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) # TODO: These transpose are quite inefficient but Flash Attention requires the layout # [batch_size, sequence_length, num_heads, head_dim]. 
We would need to refactor the KV cache @@ -184,9 +194,15 @@ def llama_attn_forward( bsz, q_len, _ = hidden_states.shape - query_states = self.q_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - key_states = self.k_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + query_states = ( + self.q_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + ) + key_states = ( + self.k_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + ) + value_states = ( + self.v_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + ) ########## AlltoAll for Ulysses ########## ulysses_sp_size = get_ulysses_sequence_parallel_world_size() @@ -206,18 +222,24 @@ def llama_attn_forward( if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + if self.config._attn_implementation == "sdpa" and kwargs.get( + "output_attentions", False + ): logger.warning_once( "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. " "Falling back to eager attention. This warning can be removed using the argument " '`attn_implementation="eager"` when loading the model.' 
) else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[ + self.config._attn_implementation + ] attn_output, attn_weights = attention_interface( self, diff --git a/Agent0/executor_train/verl/verl/models/transformers/monkey_patch.py b/Agent0/executor_train/verl/verl/models/transformers/monkey_patch.py index d6be65a..b4a460b 100644 --- a/Agent0/executor_train/verl/verl/models/transformers/monkey_patch.py +++ b/Agent0/executor_train/verl/verl/models/transformers/monkey_patch.py @@ -43,7 +43,9 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: batch, slen, num_key_value_heads, head_dim = hidden_states.shape if n_rep == 1: return hidden_states - hidden_states = hidden_states[:, :, :, None, :].expand(batch, slen, num_key_value_heads, n_rep, head_dim) + hidden_states = hidden_states[:, :, :, None, :].expand( + batch, slen, num_key_value_heads, n_rep, head_dim + ) return hidden_states.reshape(batch, slen, num_key_value_heads * n_rep, head_dim) @@ -71,7 +73,9 @@ def _ulysses_flash_attention_forward( ########## AlltoAll for Ulysses ########## if ulysses_sp_size > 1: - assert position_ids is not None, "position_ids is required for Ulysses sequence parallelism" + assert ( + position_ids is not None + ), "position_ids is required for Ulysses sequence parallelism" # NOTE: repeat kv heads to be divided by sequence parallel. Instead of repeating nheads_q//nheads_k, # we choose to repeat sp_size//nheads_k, since flash_attention supports MQA/GQA. 
@@ -93,13 +97,22 @@ def _ulysses_flash_attention_forward( # https://github.com/huggingface/transformers/pull/33932 # (bsz, seq_len/n) -> (bsz, seq_len) - position_ids_list = [torch.empty_like(position_ids) for _ in range(ulysses_sp_size)] - torch.distributed.all_gather(position_ids_list, position_ids, group=get_ulysses_sequence_parallel_group()) + position_ids_list = [ + torch.empty_like(position_ids) for _ in range(ulysses_sp_size) + ] + torch.distributed.all_gather( + position_ids_list, position_ids, group=get_ulysses_sequence_parallel_group() + ) position_ids = torch.concat(position_ids_list, dim=-1) # (bsz, seq_len, n_head/n, head_dim) attn_output = _flash_attention_forward( - query_states, key_states, value_states, *args, position_ids=position_ids, **kwargs + query_states, + key_states, + value_states, + *args, + position_ids=position_ids, + **kwargs, ) ########## AlltoAll for Ulysses ########## @@ -129,7 +142,9 @@ def ulysses_wrapped_decoder_forward(self, *args, **kwargs): and getattr(self, "_needs_initial_slice", True) ) if slice_now: - call_kwargs["inputs_embeds"] = slice_input_tensor(inputs_embeds, dim=1, padding=False) + call_kwargs["inputs_embeds"] = slice_input_tensor( + inputs_embeds, dim=1, padding=False + ) self._needs_initial_slice = False try: return original_forward(self, *args, **call_kwargs) @@ -167,17 +182,26 @@ def patch_forward_with_backends( forward_with_torch_backend_function = model.__class__.forward forward_with_triton_backend_function = model.__class__.forward if model.config.model_type == "qwen2_5_vl": - from verl.models.transformers.qwen2_5_vl import forward_with_torch_backend, forward_with_triton_backend + from verl.models.transformers.qwen2_5_vl import ( + forward_with_torch_backend, + forward_with_triton_backend, + ) forward_with_torch_backend_function = forward_with_torch_backend forward_with_triton_backend_function = forward_with_triton_backend elif model.config.model_type == "qwen2_vl": - from verl.models.transformers.qwen2_vl 
import forward_with_torch_backend, forward_with_triton_backend + from verl.models.transformers.qwen2_vl import ( + forward_with_torch_backend, + forward_with_triton_backend, + ) forward_with_torch_backend_function = forward_with_torch_backend forward_with_triton_backend_function = forward_with_triton_backend else: - from verl.models.transformers.dense_common import forward_with_torch_backend, forward_with_triton_backend + from verl.models.transformers.dense_common import ( + forward_with_torch_backend, + forward_with_triton_backend, + ) forward_with_torch_backend_function = forward_with_torch_backend forward_with_triton_backend_function = forward_with_triton_backend @@ -189,7 +213,9 @@ def patch_forward_with_backends( model.__class__.forward = forward_with_torch_backend_function print(f"Using Torch backend for fused kernels in {model.__class__.__name__}") else: - raise ValueError(f"Unsupported fused_kernels_backend: {fused_kernels_backend}. Choose 'triton' or 'torch'.") + raise ValueError( + f"Unsupported fused_kernels_backend: {fused_kernels_backend}. Choose 'triton' or 'torch'." 
+ ) def apply_monkey_patch( @@ -210,17 +236,23 @@ def apply_monkey_patch( module = sys.modules[model.__module__] try: - num_attention_heads, num_key_value_heads = model.config.num_attention_heads, model.config.num_key_value_heads + num_attention_heads, num_key_value_heads = ( + model.config.num_attention_heads, + model.config.num_key_value_heads, + ) except AttributeError: num_attention_heads, num_key_value_heads = ( model.config.text_config.num_attention_heads, model.config.text_config.num_key_value_heads, ) - assert num_attention_heads % ulysses_sp_size == 0, ( - f"num_attention_heads {num_attention_heads} must be divisible by ulysses_sp_size {ulysses_sp_size}" - ) - assert num_key_value_heads % ulysses_sp_size == 0 or ulysses_sp_size % num_key_value_heads == 0, ( + assert ( + num_attention_heads % ulysses_sp_size == 0 + ), f"num_attention_heads {num_attention_heads} must be divisible by ulysses_sp_size {ulysses_sp_size}" + assert ( + num_key_value_heads % ulysses_sp_size == 0 + or ulysses_sp_size % num_key_value_heads == 0 + ), ( f"num_key_value_heads {num_key_value_heads} must be divisible by ulysses_sp_size " f"{ulysses_sp_size}or vise versa. Upon ulysses_sp_size % num_key_value_heads == 0," f"kv heads are repeated to ensure correctness." @@ -238,7 +270,9 @@ def state_dict(self, *args, **kwargs): # TODO: VLM models only, unify monkey patch to LLM models. 
if model.config.model_type == "qwen2_5_vl": if is_transformers_version_in_range(min_version="4.53.0"): - from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLAttention + from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( + Qwen2_5_VLAttention, + ) # TODO: Support transformers 4.53 raise ValueError("Transformers 4.53 is not supported") @@ -255,11 +289,15 @@ def state_dict(self, *args, **kwargs): if ulysses_sp_size > 1: if is_transformers_version_in_range(min_version="4.52.0"): - from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLTextModel + from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( + Qwen2_5_VLTextModel, + ) patch_vlm_for_ulysses_input_slicing(Qwen2_5_VLTextModel) else: - from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLModel + from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( + Qwen2_5_VLModel, + ) patch_vlm_for_ulysses_input_slicing(Qwen2_5_VLModel) @@ -270,7 +308,9 @@ def state_dict(self, *args, **kwargs): # TODO: Support transformers 4.53 raise ValueError("Transformers 4.53 is not supported") else: - from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLFlashAttention2 as Qwen2VLAttention + from transformers.models.qwen2_vl.modeling_qwen2_vl import ( + Qwen2VLFlashAttention2 as Qwen2VLAttention, + ) if use_remove_padding or ulysses_sp_size > 1: from verl.models.transformers.qwen2_vl import ulysses_flash_attn_forward @@ -280,7 +320,9 @@ def state_dict(self, *args, **kwargs): if ulysses_sp_size > 1: if is_transformers_version_in_range(min_version="4.52.0"): - from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLTextModel + from transformers.models.qwen2_vl.modeling_qwen2_vl import ( + Qwen2VLTextModel, + ) patch_vlm_for_ulysses_input_slicing(Qwen2VLTextModel) else: @@ -314,13 +356,21 @@ def state_dict(self, *args, **kwargs): from transformers.integrations import flash_attention flash_attention._flash_attention_forward = 
_ulysses_flash_attention_forward - print(f"Monkey patch _flash_attention_forward in {flash_attention.__name__}") + print( + f"Monkey patch _flash_attention_forward in {flash_attention.__name__}" + ) - patch_forward_with_backends(model, use_fused_kernels=use_fused_kernels, fused_kernels_backend=fused_kernels_backend) + patch_forward_with_backends( + model, + use_fused_kernels=use_fused_kernels, + fused_kernels_backend=fused_kernels_backend, + ) @lru_cache -def is_transformers_version_in_range(min_version: Optional[str] = None, max_version: Optional[str] = None) -> bool: +def is_transformers_version_in_range( + min_version: Optional[str] = None, max_version: Optional[str] = None +) -> bool: try: # Get the installed version of the transformers library transformers_version_str = importlib.metadata.version("transformers") diff --git a/Agent0/executor_train/verl/verl/models/transformers/npu_patch.py b/Agent0/executor_train/verl/verl/models/transformers/npu_patch.py index e6bb373..54af9ce 100644 --- a/Agent0/executor_train/verl/verl/models/transformers/npu_patch.py +++ b/Agent0/executor_train/verl/verl/models/transformers/npu_patch.py @@ -33,10 +33,14 @@ def apply_rotary_pos_emb_flashatt_npu( cos = cos.repeat(1, 2) sin = sin.repeat(1, 2) q_embed = apply_rotary_emb( - q.float(), cos.unsqueeze(0).unsqueeze(2).float(), sin.unsqueeze(0).unsqueeze(2).float() + q.float(), + cos.unsqueeze(0).unsqueeze(2).float(), + sin.unsqueeze(0).unsqueeze(2).float(), ).type_as(q) k_embed = apply_rotary_emb( - k.float(), cos.unsqueeze(0).unsqueeze(2).float(), sin.unsqueeze(0).unsqueeze(2).float() + k.float(), + cos.unsqueeze(0).unsqueeze(2).float(), + sin.unsqueeze(0).unsqueeze(2).float(), ).type_as(k) return q_embed, k_embed diff --git a/Agent0/executor_train/verl/verl/models/transformers/qwen2.py b/Agent0/executor_train/verl/verl/models/transformers/qwen2.py index e55fb26..78e2a29 100644 --- a/Agent0/executor_train/verl/verl/models/transformers/qwen2.py +++ 
b/Agent0/executor_train/verl/verl/models/transformers/qwen2.py @@ -39,7 +39,9 @@ def qwen2_flash_attn_forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[ + tuple[torch.Tensor, torch.Tensor] + ] = None, # will become mandatory in v4.46 ): """ Adapted from transformers 4.47.1 to support Ulysses sequence parallelism. @@ -82,8 +84,14 @@ def qwen2_flash_attn_forward( query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + cache_kwargs = { + "sin": sin, + "cos": cos, + "cache_position": cache_position, + } # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) @@ -196,7 +204,9 @@ def qwen2_attn_forward( if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) sliding_window = None if ( @@ -210,14 +220,18 @@ def qwen2_attn_forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + if self.config._attn_implementation == "sdpa" 
and kwargs.get( + "output_attentions", False + ): logger.warning_once( "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. " "Falling back to eager attention. This warning can be removed using the argument " '`attn_implementation="eager"` when loading the model.' ) else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[ + self.config._attn_implementation + ] attn_output, attn_weights = attention_interface( self, diff --git a/Agent0/executor_train/verl/verl/models/transformers/qwen2_5_vl.py b/Agent0/executor_train/verl/verl/models/transformers/qwen2_5_vl.py index 51d9753..614b34c 100644 --- a/Agent0/executor_train/verl/verl/models/transformers/qwen2_5_vl.py +++ b/Agent0/executor_train/verl/verl/models/transformers/qwen2_5_vl.py @@ -51,11 +51,19 @@ def forward_base_model( Copy paste Qwen2_5_VL's forward https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/transformers/model/qwen2_5_vl.py ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if inputs_embeds is None: inputs_embeds = self.model.embed_tokens(input_ids) @@ -103,7 +111,9 @@ def forward_base_model( # if we get 4D attention mask we cannot calculate rope deltas anymore. 
TODO @raushan fixme if position_ids is None and (attention_mask is None or attention_mask.ndim == 2): # calculate RoPE index once per generation in the pre-fill stage only - if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None: + if ( + cache_position is not None and cache_position[0] == 0 + ) or self.rope_deltas is None: position_ids, rope_deltas = self.get_rope_index( input_ids, image_grid_thw, @@ -115,7 +125,11 @@ def forward_base_model( # then use the prev pre-calculated rope-deltas to get the correct position ids else: batch_size, seq_length, _ = inputs_embeds.shape - delta = (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) if cache_position is not None else 0 + delta = ( + (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) + if cache_position is not None + else 0 + ) position_ids = torch.arange(seq_length, device=inputs_embeds.device) position_ids = position_ids.view(1, -1).expand(batch_size, -1) if cache_position is not None: # otherwise `deltas` is an int `0` @@ -193,7 +207,9 @@ def forward_with_torch_backend( elif input_ids is not None: rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1) else: - raise RuntimeError("To use forward_with_torch_backend, either labels or input_ids must be provided.") + raise RuntimeError( + "To use forward_with_torch_backend, either labels or input_ids must be provided." + ) fused_linear_for_ppo = FusedLinearForPPO() log_probs, entropy = fused_linear_for_ppo.forward( @@ -268,7 +284,9 @@ def forward_with_triton_backend( elif input_ids is not None: rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1) else: - raise RuntimeError("To use forward_with_triton_backend, either labels or input_ids must be provided.") + raise RuntimeError( + "To use forward_with_triton_backend, either labels or input_ids must be provided." 
+ ) log_probs, entropy = linear_cross_entropy( hidden_states, diff --git a/Agent0/executor_train/verl/verl/models/transformers/qwen2_vl.py b/Agent0/executor_train/verl/verl/models/transformers/qwen2_vl.py index 358b00b..831081f 100644 --- a/Agent0/executor_train/verl/verl/models/transformers/qwen2_vl.py +++ b/Agent0/executor_train/verl/verl/models/transformers/qwen2_vl.py @@ -33,9 +33,14 @@ ) try: - from transformers.modeling_flash_attention_utils import flash_attn_func, flash_attn_varlen_func + from transformers.modeling_flash_attention_utils import ( + flash_attn_func, + flash_attn_varlen_func, + ) - _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) + _flash_supports_window_size = "window_size" in list( + inspect.signature(flash_attn_func).parameters + ) except ImportError: flash_attn_varlen_func = None @@ -57,12 +62,18 @@ def get_rope_index( tokens_per_second = 2 image_token_id = processor.tokenizer.convert_tokens_to_ids("<|image_pad|>") video_token_id = processor.tokenizer.convert_tokens_to_ids("<|video_pad|>") - vision_start_token_id = processor.tokenizer.convert_tokens_to_ids("<|vision_start|>") - if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): + vision_start_token_id = processor.tokenizer.convert_tokens_to_ids( + "<|vision_start|>" + ) + if input_ids is not None and ( + image_grid_thw is not None or video_grid_thw is not None + ): if attention_mask is None: attention_mask = torch.ones_like(input_ids) - position_ids = torch.ones(3, input_ids.size(0), dtype=input_ids.dtype, device=input_ids.device) # (3, seqlen) + position_ids = torch.ones( + 3, input_ids.size(0), dtype=input_ids.dtype, device=input_ids.device + ) # (3, seqlen) image_index, video_index = 0, 0 input_ids = input_ids[attention_mask == 1] image_nums, video_nums = 0, 0 @@ -99,7 +110,11 @@ def get_rope_index( video_grid_thw[video_index][1], video_grid_thw[video_index][2], ) - second_per_grid_t = 
second_per_grid_ts[video_index] if second_per_grid_ts is not None else 1.0 + second_per_grid_t = ( + second_per_grid_ts[video_index] + if second_per_grid_ts is not None + else 1.0 + ) video_index += 1 remain_videos -= 1 @@ -113,19 +128,37 @@ def get_rope_index( text_len = ed - st st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 - llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) - t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w) + t_index = ( + torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w) + ) t_index = (t_index * second_per_grid_t * tokens_per_second).long().flatten() - h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() - w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() - llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + h_index = ( + torch.arange(llm_grid_h) + .view(1, -1, 1) + .expand(llm_grid_t, -1, llm_grid_w) + .flatten() + ) + w_index = ( + torch.arange(llm_grid_w) + .view(1, 1, -1) + .expand(llm_grid_t, llm_grid_h, -1) + .flatten() + ) + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + text_len + st_idx + ) st = ed + llm_grid_t * llm_grid_h * llm_grid_w if st < len(input_tokens): st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 text_len = len(input_tokens) - st - llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) position_ids[..., attention_mask == 1] = llm_positions.to(position_ids.device) @@ -135,27 +168,47 @@ def get_rope_index( position_ids.masked_fill_(attention_mask == 0, 1) 
position_ids = position_ids.unsqueeze(0).expand(3, -1).to(input_ids.device) else: - position_ids = torch.arange(input_ids.shape[1], device=input_ids.device).view(1, -1).expand(3, -1) + position_ids = ( + torch.arange(input_ids.shape[1], device=input_ids.device) + .view(1, -1) + .expand(3, -1) + ) return position_ids def prepare_fa2_from_position_ids( - query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, position_ids: torch.Tensor + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + position_ids: torch.Tensor, ): query = query.view(-1, query.size(-2), query.size(-1)) key = key.view(-1, key.size(-2), key.size(-1)) value = value.view(-1, value.size(-2), value.size(-1)) position_ids = position_ids.flatten() - indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32) + indices_q = torch.arange( + position_ids.size(0), device=position_ids.device, dtype=torch.int32 + ) cu_seqlens = torch.cat( ( indices_q[position_ids == 0], - torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32), + torch.tensor( + position_ids.size(), device=position_ids.device, dtype=torch.int32 + ), ) ) - max_length = cu_seqlens.diff().max() # use cu_seqlens to infer max_length for qwen2vl mrope - return (query, key, value, indices_q, (cu_seqlens, cu_seqlens), (max_length, max_length)) + max_length = ( + cu_seqlens.diff().max() + ) # use cu_seqlens to infer max_length for qwen2vl mrope + return ( + query, + key, + value, + indices_q, + (cu_seqlens, cu_seqlens), + (max_length, max_length), + ) def flash_attention_forward( @@ -178,19 +231,29 @@ def flash_attention_forward( # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length). 
use_sliding_windows = ( - _flash_supports_window_size and sliding_window is not None and key_states.shape[1] > sliding_window + _flash_supports_window_size + and sliding_window is not None + and key_states.shape[1] > sliding_window + ) + flash_kwargs = ( + {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {} ) - flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {} if is_flash_attn_greater_or_equal("2.4.1"): if deterministic is None: deterministic = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1" flash_kwargs["deterministic"] = deterministic - if position_ids is not None and query_length != 1 and not (torch.diff(position_ids[0], dim=-1) >= 0).all(): + if ( + position_ids is not None + and query_length != 1 + and not (torch.diff(position_ids[0], dim=-1) >= 0).all() + ): batch_size = query_states.size(0) - query_states, key_states, value_states, _, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids( - query_states, key_states, value_states, position_ids[0] + query_states, key_states, value_states, _, cu_seq_lens, max_seq_lens = ( + prepare_fa2_from_position_ids( + query_states, key_states, value_states, position_ids[0] + ) ) # remove channel dimension cu_seqlens_q, cu_seqlens_k = cu_seq_lens max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens @@ -207,7 +270,9 @@ def flash_attention_forward( causal=causal, **flash_kwargs, ) - attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1)) + attn_output = attn_output.view( + batch_size, -1, attn_output.size(-2), attn_output.size(-1) + ) else: attn_output = _flash_attention_forward( query_states, @@ -230,19 +295,32 @@ def ulysses_flash_attn_forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + 
position_embeddings: Optional[ + tuple[torch.Tensor, torch.Tensor] + ] = None, # will become mandatory in v4.46 **kwargs, ) -> tuple[torch.Tensor, None, None]: - from transformers.models.qwen2_vl.modeling_qwen2_vl import apply_multimodal_rotary_pos_emb, repeat_kv + from transformers.models.qwen2_vl.modeling_qwen2_vl import ( + apply_multimodal_rotary_pos_emb, + repeat_kv, + ) bsz, q_len, _ = hidden_states.size() # q_len = seq_length / sp_size - query_states = self.q_proj(hidden_states) # (batch_size, seq_length / sp_size, num_heads * head_size) + query_states = self.q_proj( + hidden_states + ) # (batch_size, seq_length / sp_size, num_heads * head_size) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + query_states = query_states.view( + bsz, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) + key_states = key_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) ulysses_sp_size = get_ulysses_sequence_parallel_world_size() @@ -332,11 +410,19 @@ def forward_base_model( Copy paste Qwen2VL's forward https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/transformers/model/qwen2_vl.py ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + output_hidden_states + if output_hidden_states 
is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if inputs_embeds is None: inputs_embeds = self.model.embed_tokens(input_ids) @@ -383,13 +469,21 @@ def forward_base_model( if position_ids is None and (attention_mask is None or attention_mask.ndim == 2): # calculate RoPE index once per generation in the pre-fill stage only - if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None: - position_ids, rope_deltas = self.get_rope_index(input_ids, image_grid_thw, video_grid_thw, attention_mask) + if ( + cache_position is not None and cache_position[0] == 0 + ) or self.rope_deltas is None: + position_ids, rope_deltas = self.get_rope_index( + input_ids, image_grid_thw, video_grid_thw, attention_mask + ) self.rope_deltas = rope_deltas # then use the prev pre-calculated rope-deltas to get the correct position ids else: batch_size, seq_length, _ = inputs_embeds.shape - delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0 + delta = ( + cache_position[0] + self.rope_deltas + if cache_position is not None + else 0 + ) position_ids = torch.arange(seq_length, device=inputs_embeds.device) position_ids = position_ids.view(1, -1).expand(batch_size, -1) if cache_position is not None: # otherwise `deltas` is an int `0` @@ -466,7 +560,9 @@ def forward_with_torch_backend( elif input_ids is not None: rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1) else: - raise RuntimeError("To use forward_with_torch_backend, either labels or input_ids must be provided.") + raise RuntimeError( + "To use forward_with_torch_backend, either labels or input_ids must be provided." 
+ ) fused_linear_for_ppo = FusedLinearForPPO() log_probs, entropy = fused_linear_for_ppo.forward( @@ -539,7 +635,9 @@ def forward_with_triton_backend( elif input_ids is not None: rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1) else: - raise RuntimeError("To use forward_with_triton_backend, either labels or input_ids must be provided.") + raise RuntimeError( + "To use forward_with_triton_backend, either labels or input_ids must be provided." + ) log_probs, entropy = linear_cross_entropy( hidden_states, diff --git a/Agent0/executor_train/verl/verl/protocol.py b/Agent0/executor_train/verl/verl/protocol.py index 0029913..61324a3 100644 --- a/Agent0/executor_train/verl/verl/protocol.py +++ b/Agent0/executor_train/verl/verl/protocol.py @@ -51,12 +51,17 @@ class _DataProtoConfigMeta(type): @property def auto_padding(cls): - enabled_by_env = os.getenv("VERL_AUTO_PADDING", "FALSE").upper() in ["TRUE", "1"] + enabled_by_env = os.getenv("VERL_AUTO_PADDING", "FALSE").upper() in [ + "TRUE", + "1", + ] return enabled_by_env or cls._config.get(cls.auto_padding_key, False) @auto_padding.setter def auto_padding(cls, enabled: bool): - assert isinstance(enabled, bool), f"enabled must be a boolean, got {enabled} as {type(enabled)}" + assert isinstance( + enabled, bool + ), f"enabled must be a boolean, got {enabled} as {type(enabled)}" cls._config[cls.auto_padding_key] = enabled @@ -104,29 +109,31 @@ def unpad_dataproto(data: "DataProto", pad_size): def union_tensor_dict(tensor_dict1: TensorDict, tensor_dict2: TensorDict) -> TensorDict: """Union two tensordicts.""" - assert tensor_dict1.batch_size == tensor_dict2.batch_size, ( - f"Two tensor dict must have identical batch size. Got {tensor_dict1.batch_size} and {tensor_dict2.batch_size}" - ) + assert ( + tensor_dict1.batch_size == tensor_dict2.batch_size + ), f"Two tensor dict must have identical batch size. 
Got {tensor_dict1.batch_size} and {tensor_dict2.batch_size}" for key in tensor_dict2.keys(): if key not in tensor_dict1.keys(): tensor_dict1[key] = tensor_dict2[key] else: - assert tensor_dict1[key].equal(tensor_dict2[key]), ( - f"{key} in tensor_dict1 and tensor_dict2 are not the same object" - ) + assert tensor_dict1[key].equal( + tensor_dict2[key] + ), f"{key} in tensor_dict1 and tensor_dict2 are not the same object" return tensor_dict1 -def union_numpy_dict(tensor_dict1: dict[str, np.ndarray], tensor_dict2: dict[str, np.ndarray]) -> dict[str, np.ndarray]: +def union_numpy_dict( + tensor_dict1: dict[str, np.ndarray], tensor_dict2: dict[str, np.ndarray] +) -> dict[str, np.ndarray]: for key, val in tensor_dict2.items(): if key in tensor_dict1: assert isinstance(tensor_dict2[key], np.ndarray) assert isinstance(tensor_dict1[key], np.ndarray) # to properly deal with nan and object type - assert pd.DataFrame(tensor_dict2[key]).equals(pd.DataFrame(tensor_dict1[key])), ( - f"{key} in tensor_dict1 and tensor_dict2 are not the same object" - ) + assert pd.DataFrame(tensor_dict2[key]).equals( + pd.DataFrame(tensor_dict1[key]) + ), f"{key} in tensor_dict1 and tensor_dict2 are not the same object" tensor_dict1[key] = val return tensor_dict1 @@ -161,7 +168,9 @@ def fold_batch_dim(data: "DataProto", new_batch_size): for key, val in non_tensor.items(): non_tensor[key] = np.reshape(val, newshape=(new_batch_size, -1, *val.shape[1:])) - return type(data)(batch=tensor, non_tensor_batch=non_tensor, meta_info=data.meta_info) + return type(data)( + batch=tensor, non_tensor_batch=non_tensor, meta_info=data.meta_info + ) def unfold_batch_dim(data: "DataProto", batch_dims=2): @@ -178,9 +187,13 @@ def unfold_batch_dim(data: "DataProto", batch_dims=2): non_tensor_new = {} for key, val in non_tensor.items(): - non_tensor_new[key] = np.reshape(val, newshape=(batch_size, *val.shape[batch_dims:])) + non_tensor_new[key] = np.reshape( + val, newshape=(batch_size, *val.shape[batch_dims:]) + ) - 
return type(data)(batch=tensor, non_tensor_batch=non_tensor_new, meta_info=data.meta_info) + return type(data)( + batch=tensor, non_tensor_batch=non_tensor_new, meta_info=data.meta_info + ) def collate_fn(x: list["DataProtoItem"]): @@ -257,8 +270,14 @@ def __getitem__(self, item): # Case 3: Single integer - return DataProtoItem for backward compatibility elif isinstance(item, int | np.integer): tensor_data = self.batch[item] if self.batch is not None else None - non_tensor_data = {key: val[item] for key, val in self.non_tensor_batch.items()} - return DataProtoItem(batch=tensor_data, non_tensor_batch=non_tensor_data, meta_info=self.meta_info) + non_tensor_data = { + key: val[item] for key, val in self.non_tensor_batch.items() + } + return DataProtoItem( + batch=tensor_data, + non_tensor_batch=non_tensor_data, + meta_info=self.meta_info, + ) # # Case 4: Unsupported type else: @@ -268,7 +287,10 @@ def __getstate__(self): import io buffer = io.BytesIO() - if version.parse(tensordict.__version__) >= version.parse("0.5.0") and self.batch is not None: + if ( + version.parse(tensordict.__version__) >= version.parse("0.5.0") + and self.batch is not None + ): self.batch = self.batch.contiguous() self.batch = self.batch.consolidate() torch.save(self.batch, buffer) @@ -328,9 +350,15 @@ def check_consistency(self): for key, val in self.non_tensor_batch.items(): assert isinstance(val, np.ndarray) - if self.batch is not None and self.non_tensor_batch is not None and len(self.non_tensor_batch) != 0: + if ( + self.batch is not None + and self.non_tensor_batch is not None + and len(self.non_tensor_batch) != 0 + ): # TODO: we can actually lift this restriction if needed - assert len(self.batch.batch_size) == 1, "only support num_batch_dims=1 when non_tensor_batch is not empty." + assert ( + len(self.batch.batch_size) == 1 + ), "only support num_batch_dims=1 when non_tensor_batch is not empty." 
batch_size = self.batch.batch_size[0] for key, val in self.non_tensor_batch.items(): @@ -338,12 +366,17 @@ def check_consistency(self): f"data in the non_tensor_batch must be a numpy.array with dtype=object, but for " f"{key=}, got {type(val)=}" ) - assert val.shape[0] == batch_size, ( - f"key {key} length {len(val)} is not equal to batch size {batch_size}" - ) + assert ( + val.shape[0] == batch_size + ), f"key {key} length {len(val)} is not equal to batch size {batch_size}" @classmethod - def from_single_dict(cls, data: dict[str, torch.Tensor | np.ndarray], meta_info=None, auto_padding=False): + def from_single_dict( + cls, + data: dict[str, torch.Tensor | np.ndarray], + meta_info=None, + auto_padding=False, + ): """Create a DataProto from a dict of tensors and non_tensors""" tensors = {} non_tensors = {} @@ -356,7 +389,12 @@ def from_single_dict(cls, data: dict[str, torch.Tensor | np.ndarray], meta_info= else: raise ValueError(f"Unsupported type in data {type(val)}") - return cls.from_dict(tensors=tensors, non_tensors=non_tensors, meta_info=meta_info, auto_padding=auto_padding) + return cls.from_dict( + tensors=tensors, + non_tensors=non_tensors, + meta_info=meta_info, + auto_padding=auto_padding, + ) @classmethod def from_dict( @@ -374,7 +412,9 @@ def from_dict( assert num_batch_dims > 0, "num_batch_dims must be greater than zero" if non_tensors is not None: - assert num_batch_dims == 1, "only support num_batch_dims=1 when non_tensors is not None." + assert ( + num_batch_dims == 1 + ), "only support num_batch_dims=1 when non_tensors is not None." 
if tensors is None: tensors = {} @@ -403,7 +443,9 @@ def from_dict( if not isinstance(val, np.ndarray): non_tensors[key] = np.array(val, dtype=object) - tensor_dict = TensorDict(source=tensors, batch_size=batch_size) if tensors else None + tensor_dict = ( + TensorDict(source=tensors, batch_size=batch_size) if tensors else None + ) if auto_padding: meta_info[DataProtoConfig.auto_padding_key] = True return cls(batch=tensor_dict, non_tensor_batch=non_tensors, meta_info=meta_info) @@ -422,7 +464,13 @@ def to(self, device) -> "DataProto": self.batch = self.batch.to(device) return self - def select(self, batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=None, deepcopy=False) -> "DataProto": + def select( + self, + batch_keys=None, + non_tensor_batch_keys=None, + meta_info_keys=None, + deepcopy=False, + ) -> "DataProto": """Select a subset of the DataProto via batch_keys and meta_info_keys Args: @@ -440,7 +488,11 @@ def select(self, batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=Non sub_batch = self.batch if non_tensor_batch_keys is not None: - non_tensor_batch = {key: val for key, val in self.non_tensor_batch.items() if key in non_tensor_batch_keys} + non_tensor_batch = { + key: val + for key, val in self.non_tensor_batch.items() + if key in non_tensor_batch_keys + } else: non_tensor_batch = self.non_tensor_batch @@ -448,14 +500,18 @@ def select(self, batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=Non non_tensor_batch = copy.deepcopy(non_tensor_batch) if meta_info_keys is not None: - sub_meta_info = {key: val for key, val in self.meta_info.items() if key in meta_info_keys} + sub_meta_info = { + key: val for key, val in self.meta_info.items() if key in meta_info_keys + } else: sub_meta_info = self.meta_info if deepcopy: sub_meta_info = copy.deepcopy(sub_meta_info) - return type(self)(batch=sub_batch, non_tensor_batch=non_tensor_batch, meta_info=sub_meta_info) + return type(self)( + batch=sub_batch, non_tensor_batch=non_tensor_batch, 
meta_info=sub_meta_info + ) def select_idxs(self, idxs): """ @@ -495,7 +551,11 @@ def select_idxs(self, idxs): for key, val in self.non_tensor_batch.items(): selected_non_tensor[key] = val[idxs_np] - return type(self)(batch=selected_batch, non_tensor_batch=selected_non_tensor, meta_info=self.meta_info) + return type(self)( + batch=selected_batch, + non_tensor_batch=selected_non_tensor, + meta_info=self.meta_info, + ) def slice(self, start=None, end=None, step=None): """ @@ -541,9 +601,15 @@ def slice(self, start=None, end=None, step=None): sliced_non_tensor[key] = val[slice_obj] # Return a new DataProto object - return type(self)(batch=sliced_batch, non_tensor_batch=sliced_non_tensor, meta_info=self.meta_info) + return type(self)( + batch=sliced_batch, + non_tensor_batch=sliced_non_tensor, + meta_info=self.meta_info, + ) - def pop(self, batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=None) -> "DataProto": + def pop( + self, batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=None + ) -> "DataProto": """Pop a subset of the DataProto via `batch_keys` and `meta_info_keys` Args: @@ -574,7 +640,9 @@ def pop(self, batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=None) for key in meta_info_keys: assert key in self.meta_info.keys() meta_info[key] = self.meta_info.pop(key) - return DataProto.from_dict(tensors=tensors, non_tensors=non_tensors, meta_info=meta_info) + return DataProto.from_dict( + tensors=tensors, non_tensors=non_tensors, meta_info=meta_info + ) def rename(self, old_keys=None, new_keys=None) -> "DataProto": """ @@ -588,7 +656,9 @@ def validate_input(keys): elif isinstance(keys, list): pass else: - raise TypeError(f"keys must be a list or a string, but got {type(keys)}") + raise TypeError( + f"keys must be a list or a string, but got {type(keys)}" + ) return keys old_keys = validate_input(old_keys) @@ -618,7 +688,9 @@ def union(self, other: "DataProto") -> "DataProto": DataProto: the DataProto after union """ self.batch = 
union_tensor_dict(self.batch, other.batch) - self.non_tensor_batch = union_numpy_dict(self.non_tensor_batch, other.non_tensor_batch) + self.non_tensor_batch = union_numpy_dict( + self.non_tensor_batch, other.non_tensor_batch + ) self.meta_info = union_two_dict(self.meta_info, other.meta_info) return self @@ -638,7 +710,9 @@ def make_iterator(self, mini_batch_size, epochs, seed=None, dataloader_kwargs=No Iterator: an iterator that yields a mini-batch data at a time. The total number of iteration steps is ``self.batch.batch_size * epochs // mini_batch_size`` """ - assert self.batch.batch_size[0] % mini_batch_size == 0, f"{self.batch.batch_size[0]} % {mini_batch_size} != 0" + assert ( + self.batch.batch_size[0] % mini_batch_size == 0 + ), f"{self.batch.batch_size[0]} % {mini_batch_size} != 0" # we can directly create a dataloader from TensorDict if dataloader_kwargs is None: dataloader_kwargs = {} @@ -651,7 +725,11 @@ def make_iterator(self, mini_batch_size, epochs, seed=None, dataloader_kwargs=No assert isinstance(dataloader_kwargs, dict) train_dataloader = DataLoader( - dataset=self, batch_size=mini_batch_size, collate_fn=collate_fn, generator=generator, **dataloader_kwargs + dataset=self, + batch_size=mini_batch_size, + collate_fn=collate_fn, + generator=generator, + **dataloader_kwargs, ) def get_data(): @@ -668,7 +746,9 @@ def is_padding_enabled(self): Returns: bool: True if padding is enabled, False otherwise. 
""" - dataproto_specific_padding = self.meta_info.get(DataProtoConfig.auto_padding_key, False) + dataproto_specific_padding = self.meta_info.get( + DataProtoConfig.auto_padding_key, False + ) return dataproto_specific_padding or DataProtoConfig.auto_padding def padding(self, padding_size, padding_candidate=""): @@ -680,7 +760,9 @@ def padding(self, padding_size, padding_candidate=""): """ if padding_size == 0: return - padding_candidate = self.select_idxs([0 if padding_candidate == "first" else len(self) - 1]) + padding_candidate = self.select_idxs( + [0 if padding_candidate == "first" else len(self) - 1] + ) padding_part = padding_candidate.repeat(padding_size) padded_dp = DataProto.concat([self, padding_part]) self.batch = padded_dp.batch @@ -696,9 +778,9 @@ def chunk(self, chunks: int) -> list["DataProto"]: List[DataProto]: a list of DataProto after splitting """ if not self.is_padding_enabled(): - assert len(self) % chunks == 0, ( - f"only support equal chunk. Got size of DataProto {len(self)} and chunk {chunks}." - ) + assert ( + len(self) % chunks == 0 + ), f"only support equal chunk. Got size of DataProto {len(self)} and chunk {chunks}." 
bsz_in_batch = None if self.batch is not None: @@ -722,7 +804,11 @@ def chunk(self, chunks: int) -> list["DataProto"]: output = [] for i in range(chunks): output.append( - type(self)(batch=batch_lst[i], non_tensor_batch=non_tensor_batch_lst[i], meta_info=self.meta_info) + type(self)( + batch=batch_lst[i], + non_tensor_batch=non_tensor_batch_lst[i], + meta_info=self.meta_info, + ) ) return output @@ -743,12 +829,18 @@ def concat(data: list["DataProto"]) -> "DataProto": batch_lst.append(batch.batch) new_batch = torch.cat(batch_lst, dim=0) if batch_lst[0] is not None else None - non_tensor_batch = list_of_dict_to_dict_of_list(list_of_dict=[d.non_tensor_batch for d in data]) + non_tensor_batch = list_of_dict_to_dict_of_list( + list_of_dict=[d.non_tensor_batch for d in data] + ) for key, val in non_tensor_batch.items(): non_tensor_batch[key] = np.concatenate(val, axis=0) cls = type(data[0]) if len(data) > 0 else DataProto - return cls(batch=new_batch, non_tensor_batch=non_tensor_batch, meta_info=data[0].meta_info) + return cls( + batch=new_batch, + non_tensor_batch=non_tensor_batch, + meta_info=data[0].meta_info, + ) def reorder(self, indices): """ @@ -756,7 +848,9 @@ def reorder(self, indices): """ indices_np = indices.detach().numpy() self.batch = self.batch[indices] - self.non_tensor_batch = {key: val[indices_np] for key, val in self.non_tensor_batch.items()} + self.non_tensor_batch = { + key: val[indices_np] for key, val in self.non_tensor_batch.items() + } def repeat(self, repeat_times=2, interleave=True): """ @@ -773,12 +867,15 @@ def repeat(self, repeat_times=2, interleave=True): if interleave: # Interleave the data repeated_tensors = { - key: tensor.repeat_interleave(repeat_times, dim=0) for key, tensor in self.batch.items() + key: tensor.repeat_interleave(repeat_times, dim=0) + for key, tensor in self.batch.items() } else: # Stack the data repeated_tensors = { - key: tensor.unsqueeze(0).expand(repeat_times, *tensor.shape).reshape(-1, *tensor.shape[1:]) + key: 
tensor.unsqueeze(0) + .expand(repeat_times, *tensor.shape) + .reshape(-1, *tensor.shape[1:]) for key, tensor in self.batch.items() } @@ -794,7 +891,9 @@ def repeat(self, repeat_times=2, interleave=True): if interleave: repeated_non_tensor_batch[key] = np.repeat(val, repeat_times, axis=0) else: - repeated_non_tensor_batch[key] = np.tile(val, (repeat_times,) + (1,) * (val.ndim - 1)) + repeated_non_tensor_batch[key] = np.tile( + val, (repeat_times,) + (1,) * (val.ndim - 1) + ) return type(self)( batch=repeated_batch, @@ -802,7 +901,9 @@ def repeat(self, repeat_times=2, interleave=True): meta_info=self.meta_info, ) - def unfold_column_chunks(self, n_split: int, split_keys: Optional[list[str]] = None): + def unfold_column_chunks( + self, n_split: int, split_keys: Optional[list[str]] = None + ): """Split along the second dim into `n_split`, unfold it to the first dim (batch dim) Useful in passing grouped tensors that doesn't want to be shuffled in dataset. keys not in split_keys are repeated to match the shape @@ -817,10 +918,14 @@ def unfold_column_chunks(self, n_split: int, split_keys: Optional[list[str]] = N shape[1] = self.batch[key].shape[1] // n_split unfolded_batch[key] = self.batch[key].reshape(*shape) else: - unfolded_batch[key] = torch.repeat_interleave(self.batch[key], n_split, dim=0) + unfolded_batch[key] = torch.repeat_interleave( + self.batch[key], n_split, dim=0 + ) # locate the `unfolded_batch` as a TensorDict on the same device as the original batch unfolded_batch = TensorDict( - source=unfolded_batch, batch_size=(self.batch.batch_size[0] * n_split,), device=self.batch.device + source=unfolded_batch, + batch_size=(self.batch.batch_size[0] * n_split,), + device=self.batch.device, ) else: unfolded_batch = None @@ -860,15 +965,16 @@ def sample_level_repeat(self, repeat_times): assert len(repeat_times.shape) == 1 repeat_times = repeat_times.tolist() else: - assert isinstance(repeat_times, list), ( - f"repeat_times type must be in [list, torch.Tensor, 
np.ndarray, tuple], got {type(repeat_times)}" - ) + assert isinstance( + repeat_times, list + ), f"repeat_times type must be in [list, torch.Tensor, np.ndarray, tuple], got {type(repeat_times)}" repeat_times = torch.tensor(repeat_times) if self.batch is not None: # Interleave the data repeated_tensors = { - key: tensor.repeat_interleave(repeat_times, dim=0) for key, tensor in self.batch.items() + key: tensor.repeat_interleave(repeat_times, dim=0) + for key, tensor in self.batch.items() } repeated_batch = TensorDict( @@ -924,7 +1030,9 @@ def dispatch_fn(x, i, chunks): return x.chunk(chunks=chunks)[i] arg_future = DataProtoFuture( - collect_fn=self.collect_fn, dispatch_fn=partial(dispatch_fn, i=i, chunks=chunks), futures=self.futures + collect_fn=self.collect_fn, + dispatch_fn=partial(dispatch_fn, i=i, chunks=chunks), + futures=self.futures, ) arg_future_lst.append(arg_future) return arg_future_lst @@ -945,9 +1053,16 @@ def all_gather_data_proto(data: DataProto, process_group): assert isinstance(data, DataProto) prev_device = data.batch.device data.batch = data.batch.to(get_device_id()) - data.batch = allgather_dict_tensors(data.batch.contiguous(), size=group_size, group=process_group, dim=0) + data.batch = allgather_dict_tensors( + data.batch.contiguous(), size=group_size, group=process_group, dim=0 + ) data.batch = data.batch.to(prev_device) # all gather non_tensor_batch all_non_tensor_batch = [None for _ in range(group_size)] - torch.distributed.all_gather_object(all_non_tensor_batch, data.non_tensor_batch, group=process_group) - data.non_tensor_batch = {k: np.concatenate([d[k] for d in all_non_tensor_batch]) for k in data.non_tensor_batch} + torch.distributed.all_gather_object( + all_non_tensor_batch, data.non_tensor_batch, group=process_group + ) + data.non_tensor_batch = { + k: np.concatenate([d[k] for d in all_non_tensor_batch]) + for k in data.non_tensor_batch + } diff --git a/Agent0/executor_train/verl/verl/single_controller/__init__.py 
b/Agent0/executor_train/verl/verl/single_controller/__init__.py index ad6c42a..2cb36d5 100644 --- a/Agent0/executor_train/verl/verl/single_controller/__init__.py +++ b/Agent0/executor_train/verl/verl/single_controller/__init__.py @@ -19,7 +19,9 @@ version_folder = os.path.dirname(os.path.join(os.path.abspath(__file__))) # Note(haibin.lin): single_controller.__version__ is deprecated -with open(os.path.join(os.path.join(version_folder, os.pardir), "version/version")) as f: +with open( + os.path.join(os.path.join(version_folder, os.pardir), "version/version") +) as f: __version__ = f.read().strip() diff --git a/Agent0/executor_train/verl/verl/single_controller/base/decorator.py b/Agent0/executor_train/verl/verl/single_controller/base/decorator.py index 1008a79..31caa8c 100644 --- a/Agent0/executor_train/verl/verl/single_controller/base/decorator.py +++ b/Agent0/executor_train/verl/verl/single_controller/base/decorator.py @@ -103,12 +103,16 @@ def _split_args_kwargs_data_proto_with_auto_padding(chunks, *args, **kwargs): # for padding, we only support DataProto with same length if data_proto_len is None: data_proto_len = len(arg) - padding_size = (chunks - (data_proto_len % chunks)) if (data_proto_len % chunks > 0) else 0 + padding_size = ( + (chunks - (data_proto_len % chunks)) + if (data_proto_len % chunks > 0) + else 0 + ) splitted_kwargs[_padding_size_key] = padding_size else: - assert data_proto_len == len(arg), ( - f"expecting all arg share same length of {data_proto_len}, but got {len(arg)}" - ) + assert data_proto_len == len( + arg + ), f"expecting all arg share same length of {data_proto_len}, but got {len(arg)}" data_proto_len = len(arg) arg.padding(padding_size=padding_size) @@ -123,9 +127,9 @@ def _split_args_kwargs_data_proto_with_auto_padding(chunks, *args, **kwargs): padding_size = chunks - (data_proto_len % chunks) splitted_kwargs[_padding_size_key] = padding_size else: - assert data_proto_len == len(val), ( - f"expecting all arg share same length of 
{data_proto_len}, but got {len(val)}" - ) + assert data_proto_len == len( + val + ), f"expecting all arg share same length of {data_proto_len}, but got {len(val)}" data_proto_len = len(val) splitted_kwargs[key] = val.chunk(chunks=chunks) @@ -156,9 +160,9 @@ def dispatch_megatron_compute(worker_group, *args, **kwargs): """ from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup - assert isinstance(worker_group, MegatronWorkerGroup), ( - f"worker_group must be MegatronWorkerGroup, Got {type(worker_group)}" - ) + assert isinstance( + worker_group, MegatronWorkerGroup + ), f"worker_group must be MegatronWorkerGroup, Got {type(worker_group)}" # ray put all the args in advance to avoid duplicate serialization cost import ray @@ -198,7 +202,11 @@ def collect_megatron_compute(worker_group, output): pp_size = worker_group.get_megatron_global_info().pp_size for global_rank in range(worker_group.world_size): local_rank_info = worker_group.get_megatron_rank_info(rank=global_rank) - if local_rank_info.tp_rank == 0 and local_rank_info.pp_rank == pp_size - 1 and local_rank_info.cp_rank == 0: + if ( + local_rank_info.tp_rank == 0 + and local_rank_info.pp_rank == pp_size - 1 + and local_rank_info.cp_rank == 0 + ): output_in_dp.append(output[global_rank]) return output_in_dp @@ -211,7 +219,9 @@ def dispatch_megatron_compute_data_proto(worker_group, *args, **kwargs): assert isinstance(worker_group, MegatronWorkerGroup) - splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(worker_group.dp_size, *args, **kwargs) + splitted_args, splitted_kwargs = _split_args_kwargs_data_proto( + worker_group.dp_size, *args, **kwargs + ) return dispatch_megatron_compute(worker_group, *splitted_args, **splitted_kwargs) @@ -244,7 +254,9 @@ def collect_megatron_compute_data_proto(worker_group, output): output = collect_megatron_compute(worker_group, output) for o in output: - assert isinstance(o, DataProto | ray.ObjectRef), f"expecting {o} to be DataProto, but got 
{type(o)}" + assert isinstance( + o, DataProto | ray.ObjectRef + ), f"expecting {o} to be DataProto, but got {type(o)}" return _concat_data_proto_or_future(output) @@ -289,7 +301,9 @@ def dispatch_megatron_pp_as_dp(worker_group, *args, **kwargs): all_kwargs = {} for k, v in kwargs.items(): - assert isinstance(v, list | tuple) and len(v) == pp_dp_cp_size, f"expect len(v)=={pp_dp_cp_size}, got {len(v)}" + assert ( + isinstance(v, list | tuple) and len(v) == pp_dp_cp_size + ), f"expect len(v)=={pp_dp_cp_size}, got {len(v)}" transformed_v = [] for i in range(worker_group.world_size): local_dp_rank = worker_group.get_megatron_rank_info(rank=i).dp_rank @@ -339,7 +353,9 @@ def dispatch_megatron_pp_as_dp_data_proto(worker_group, *args, **kwargs): assert isinstance(worker_group, MegatronWorkerGroup) pp_dp_cp_size = worker_group.dp_size * worker_group.pp_size * worker_group.cp_size - splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(pp_dp_cp_size, *args, **kwargs) + splitted_args, splitted_kwargs = _split_args_kwargs_data_proto( + pp_dp_cp_size, *args, **kwargs + ) ret = dispatch_megatron_pp_as_dp(worker_group, *splitted_args, **splitted_kwargs) return ret @@ -391,7 +407,9 @@ def dispatch_dp_compute_data_proto_with_func(worker_group, *args, **kwargs): assert isinstance(worker_group, WorkerGroup) assert isinstance(args[0], FunctionType) # NOTE: The first one args is a function! 
- splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(worker_group.world_size, *args[1:], **kwargs) + splitted_args, splitted_kwargs = _split_args_kwargs_data_proto( + worker_group.world_size, *args[1:], **kwargs + ) splitted_args_with_func = [[args[0]] * worker_group.world_size] + splitted_args return splitted_args_with_func, splitted_kwargs @@ -402,7 +420,9 @@ def collect_dp_compute_data_proto(worker_group, output): from verl.protocol import DataProto for o in output: - assert isinstance(o, DataProto | ray.ObjectRef), f"expecting {o} to be DataProto, but got {type(o)}" + assert isinstance( + o, DataProto | ray.ObjectRef + ), f"expecting {o} to be DataProto, but got {type(o)}" output = collect_dp_compute(worker_group, output) return _concat_data_proto_or_future(output) @@ -426,7 +446,10 @@ def collect_dp_compute_data_proto(worker_group, output): "dispatch_fn": dispatch_megatron_pp_as_dp, "collect_fn": collect_megatron_pp_as_dp, }, - Dispatch.MEGATRON_PP_ONLY: {"dispatch_fn": dispatch_one_to_all, "collect_fn": collect_megatron_pp_only}, + Dispatch.MEGATRON_PP_ONLY: { + "dispatch_fn": dispatch_one_to_all, + "collect_fn": collect_megatron_pp_only, + }, Dispatch.MEGATRON_COMPUTE_PROTO: { "dispatch_fn": dispatch_megatron_compute_data_proto, "collect_fn": collect_megatron_compute_data_proto, @@ -435,7 +458,10 @@ def collect_dp_compute_data_proto(worker_group, output): "dispatch_fn": dispatch_megatron_pp_as_dp_data_proto, "collect_fn": collect_megatron_pp_as_dp_data_proto, }, - Dispatch.DP_COMPUTE: {"dispatch_fn": dispatch_dp_compute, "collect_fn": collect_dp_compute}, + Dispatch.DP_COMPUTE: { + "dispatch_fn": dispatch_dp_compute, + "collect_fn": collect_dp_compute, + }, Dispatch.DP_COMPUTE_PROTO: { "dispatch_fn": dispatch_dp_compute_data_proto, "collect_fn": collect_dp_compute_data_proto, @@ -444,7 +470,10 @@ def collect_dp_compute_data_proto(worker_group, output): "dispatch_fn": dispatch_dp_compute_data_proto_with_func, "collect_fn": 
collect_dp_compute_data_proto, }, - Dispatch.DP_COMPUTE_METRIC: {"dispatch_fn": dispatch_dp_compute_data_proto, "collect_fn": collect_dp_compute}, + Dispatch.DP_COMPUTE_METRIC: { + "dispatch_fn": dispatch_dp_compute_data_proto, + "collect_fn": collect_dp_compute, + }, Dispatch.DIRECT_ROLLOUT_METHOD: { "dispatch_fn": dummy_direct_rollout_call, "collect_fn": dummy_direct_rollout_call, @@ -462,8 +491,13 @@ def register_dispatch_mode(dispatch_mode_name, dispatch_fn, collect_fn): """ dispatch_mode = Dispatch.register(dispatch_mode_name) _check_dispatch_mode(dispatch_mode) - assert dispatch_mode not in DISPATCH_MODE_FN_REGISTRY, f"dispatch_mode_name {dispatch_mode_name} already exists" - DISPATCH_MODE_FN_REGISTRY[dispatch_mode] = {"dispatch_fn": dispatch_fn, "collect_fn": collect_fn} + assert ( + dispatch_mode not in DISPATCH_MODE_FN_REGISTRY + ), f"dispatch_mode_name {dispatch_mode_name} already exists" + DISPATCH_MODE_FN_REGISTRY[dispatch_mode] = { + "dispatch_fn": dispatch_fn, + "collect_fn": collect_fn, + } def update_dispatch_mode(dispatch_mode, dispatch_fn, collect_fn): @@ -471,8 +505,13 @@ def update_dispatch_mode(dispatch_mode, dispatch_fn, collect_fn): Update the dispatch mode. """ _check_dispatch_mode(dispatch_mode) - assert dispatch_mode in DISPATCH_MODE_FN_REGISTRY, f"dispatch_mode {dispatch_mode} not found" - DISPATCH_MODE_FN_REGISTRY[dispatch_mode] = {"dispatch_fn": dispatch_fn, "collect_fn": collect_fn} + assert ( + dispatch_mode in DISPATCH_MODE_FN_REGISTRY + ), f"dispatch_mode {dispatch_mode} not found" + DISPATCH_MODE_FN_REGISTRY[dispatch_mode] = { + "dispatch_fn": dispatch_fn, + "collect_fn": collect_fn, + } def get_predefined_execute_fn(execute_mode): @@ -488,17 +527,21 @@ def get_predefined_execute_fn(execute_mode): def _check_dispatch_mode(dispatch_mode): - assert isinstance(dispatch_mode, Dispatch | dict), ( - f"dispatch_mode must be a Dispatch or a Dict. 
Got {dispatch_mode}" - ) + assert isinstance( + dispatch_mode, Dispatch | dict + ), f"dispatch_mode must be a Dispatch or a Dict. Got {dispatch_mode}" if isinstance(dispatch_mode, dict): necessary_keys = ["dispatch_fn", "collect_fn"] for key in necessary_keys: - assert key in dispatch_mode, f"key {key} should be in dispatch_mode if it is a dictionary" + assert ( + key in dispatch_mode + ), f"key {key} should be in dispatch_mode if it is a dictionary" def _check_execute_mode(execute_mode): - assert isinstance(execute_mode, Execute), f"execute_mode must be a Execute. Got {execute_mode}" + assert isinstance( + execute_mode, Execute + ), f"execute_mode must be a Execute. Got {execute_mode}" def _materialize_futures(*args, **kwargs): @@ -516,7 +559,12 @@ def _materialize_futures(*args, **kwargs): return new_args, kwargs -def register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.ALL, blocking=True, materialize_futures=True): +def register( + dispatch_mode=Dispatch.ALL_TO_ALL, + execute_mode=Execute.ALL, + blocking=True, + materialize_futures=True, +): """Register a function with distributed execution configuration. 
This decorator registers a function with specific dispatch and execution modes @@ -554,7 +602,11 @@ async def async_inner(*args, **kwargs): return await func(*args, **kwargs) wrapper = async_inner if inspect.iscoroutinefunction(func) else inner - attrs = {"dispatch_mode": dispatch_mode, "execute_mode": execute_mode, "blocking": blocking} + attrs = { + "dispatch_mode": dispatch_mode, + "execute_mode": execute_mode, + "blocking": blocking, + } setattr(wrapper, MAGIC_ATTR, attrs) return wrapper diff --git a/Agent0/executor_train/verl/verl/single_controller/base/megatron/worker.py b/Agent0/executor_train/verl/verl/single_controller/base/megatron/worker.py index baf6eb8..975b697 100644 --- a/Agent0/executor_train/verl/verl/single_controller/base/megatron/worker.py +++ b/Agent0/executor_train/verl/verl/single_controller/base/megatron/worker.py @@ -26,7 +26,9 @@ def get_megatron_global_info(self): dp_size = mpu.get_data_parallel_world_size() pp_size = mpu.get_pipeline_model_parallel_world_size() cp_size = mpu.get_context_parallel_world_size() - info = DistGlobalInfo(tp_size=tp_size, dp_size=dp_size, pp_size=pp_size, cp_size=cp_size) + info = DistGlobalInfo( + tp_size=tp_size, dp_size=dp_size, pp_size=pp_size, cp_size=cp_size + ) return info def get_megatron_rank_info(self): @@ -36,7 +38,9 @@ def get_megatron_rank_info(self): dp_rank = mpu.get_data_parallel_rank() pp_rank = mpu.get_pipeline_model_parallel_rank() cp_rank = mpu.get_context_parallel_rank() - info = DistRankInfo(tp_rank=tp_rank, dp_rank=dp_rank, pp_rank=pp_rank, cp_rank=cp_rank) + info = DistRankInfo( + tp_rank=tp_rank, dp_rank=dp_rank, pp_rank=pp_rank, cp_rank=cp_rank + ) return info def _init_hf_config_and_tf_config( @@ -59,11 +63,19 @@ def _init_hf_config_and_tf_config( # Step 1: initialize the tokenizer self.local_path = copy_to_local(model_path) if tokenizer_or_path is None: - self.tokenizer = hf_tokenizer(self.local_path, trust_remote_code=trust_remote_code) - self.processor = 
hf_processor(self.local_path, trust_remote_code=trust_remote_code) + self.tokenizer = hf_tokenizer( + self.local_path, trust_remote_code=trust_remote_code + ) + self.processor = hf_processor( + self.local_path, trust_remote_code=trust_remote_code + ) elif isinstance(tokenizer_or_path, str): - self.tokenizer = hf_tokenizer(copy_to_local(tokenizer_or_path), trust_remote_code=trust_remote_code) - self.processor = hf_processor(copy_to_local(tokenizer_or_path), trust_remote_code=trust_remote_code) + self.tokenizer = hf_tokenizer( + copy_to_local(tokenizer_or_path), trust_remote_code=trust_remote_code + ) + self.processor = hf_processor( + copy_to_local(tokenizer_or_path), trust_remote_code=trust_remote_code + ) else: self.tokenizer = tokenizer_or_path self.processor = tokenizer_or_path @@ -75,7 +87,9 @@ def _init_hf_config_and_tf_config( self.tokenizer.chat_template = self.config.model.custom_chat_template # Step 2: get the hf - hf_config = AutoConfig.from_pretrained(self.local_path, trust_remote_code=trust_remote_code) + hf_config = AutoConfig.from_pretrained( + self.local_path, trust_remote_code=trust_remote_code + ) # Step 3: override the hf config override_config_kwargs = { @@ -84,7 +98,9 @@ def _init_hf_config_and_tf_config( "pad_token_id": self.tokenizer.pad_token_id, } override_config_kwargs.update(override_model_config.get("model_config", {})) - self.share_embeddings_and_output_weights = getattr(hf_config, "tie_word_embeddings", False) + self.share_embeddings_and_output_weights = getattr( + hf_config, "tie_word_embeddings", False + ) update_model_config(hf_config, override_config_kwargs=override_config_kwargs) self.architectures = getattr(hf_config, "architectures", None) if self.rank == 0: @@ -94,12 +110,18 @@ def _init_hf_config_and_tf_config( def add_optimization_config_to_tf_config(tf_config): # add optimization config to tf_config, e.g. 
checkpointing if self.config.model.get("enable_gradient_checkpointing", False): - gradient_checkpointing_cfg = dict(self.config.model.get("gradient_checkpointing_kwargs", dict())) - tf_config.recompute_method = gradient_checkpointing_cfg.get("activations_checkpoint_method", "full") + gradient_checkpointing_cfg = dict( + self.config.model.get("gradient_checkpointing_kwargs", dict()) + ) + tf_config.recompute_method = gradient_checkpointing_cfg.get( + "activations_checkpoint_method", "full" + ) tf_config.recompute_granularity = gradient_checkpointing_cfg.get( "activations_checkpoint_granularity", "full" ) - tf_config.recompute_num_layers = gradient_checkpointing_cfg.get("activations_checkpoint_num_layers", -1) + tf_config.recompute_num_layers = gradient_checkpointing_cfg.get( + "activations_checkpoint_num_layers", -1 + ) if megatron_config := self.config.get("megatron", {}): if extra := megatron_config.get("extra", {}): for k, v in extra.items(): diff --git a/Agent0/executor_train/verl/verl/single_controller/base/megatron/worker_group.py b/Agent0/executor_train/verl/verl/single_controller/base/megatron/worker_group.py index b9beb84..5768041 100644 --- a/Agent0/executor_train/verl/verl/single_controller/base/megatron/worker_group.py +++ b/Agent0/executor_train/verl/verl/single_controller/base/megatron/worker_group.py @@ -25,30 +25,42 @@ def __init__(self, resource_pool: ResourcePool, **kwargs): self._megatron_global_info: DistGlobalInfo = None def init_megatron(self, default_megatron_kwargs: dict = None): - raise NotImplementedError("MegatronWorkerGroup.init_megatron should be overwritten") + raise NotImplementedError( + "MegatronWorkerGroup.init_megatron should be overwritten" + ) def get_megatron_rank_info(self, rank: int) -> DistRankInfo: - assert 0 <= rank < self.world_size, f"rank must be from [0, world_size), Got {rank}" + assert ( + 0 <= rank < self.world_size + ), f"rank must be from [0, world_size), Got {rank}" return self._megatron_rank_info[rank] @property 
def tp_size(self): - assert self._megatron_global_info is not None, "MegatronWorkerGroup._megatron_global_info must be initialized" + assert ( + self._megatron_global_info is not None + ), "MegatronWorkerGroup._megatron_global_info must be initialized" return self._megatron_global_info.tp_size @property def dp_size(self): - assert self._megatron_global_info is not None, "MegatronWorkerGroup._megatron_global_info must be initialized" + assert ( + self._megatron_global_info is not None + ), "MegatronWorkerGroup._megatron_global_info must be initialized" return self._megatron_global_info.dp_size @property def pp_size(self): - assert self._megatron_global_info is not None, "MegatronWorkerGroup._megatron_global_info must be initialized" + assert ( + self._megatron_global_info is not None + ), "MegatronWorkerGroup._megatron_global_info must be initialized" return self._megatron_global_info.pp_size @property def cp_size(self): - assert self._megatron_global_info is not None, "MegatronWorkerGroup._megatron_global_info must be initialized" + assert ( + self._megatron_global_info is not None + ), "MegatronWorkerGroup._megatron_global_info must be initialized" return self._megatron_global_info.cp_size def get_megatron_global_info(self): diff --git a/Agent0/executor_train/verl/verl/single_controller/base/worker.py b/Agent0/executor_train/verl/verl/single_controller/base/worker.py index 561b9ba..2cd856b 100644 --- a/Agent0/executor_train/verl/verl/single_controller/base/worker.py +++ b/Agent0/executor_train/verl/verl/single_controller/base/worker.py @@ -96,8 +96,13 @@ def __new__(cls, *args, **kwargs): worker_group_prefix = os.environ.get("WG_PREFIX", None) # when decorator @ray.remote applies, __new__ will be called while we don't want to apply _configure_before_init - if None not in [rank, worker_group_prefix] and "ActorClass(" not in cls.__name__: - instance._configure_before_init(f"{worker_group_prefix}_register_center", int(rank)) + if ( + None not in [rank, 
worker_group_prefix] + and "ActorClass(" not in cls.__name__ + ): + instance._configure_before_init( + f"{worker_group_prefix}_register_center", int(rank) + ) return instance @@ -120,7 +125,9 @@ def _configure_before_init(self, register_center_name: str, rank: int): } if os.getenv("WG_BACKEND", None) == "ray": - from verl.single_controller.base.register_center.ray import create_worker_group_register_center + from verl.single_controller.base.register_center.ray import ( + create_worker_group_register_center, + ) self.register_center = create_worker_group_register_center( name=register_center_name, info=rank_zero_info @@ -131,7 +138,11 @@ def _configure_before_init(self, register_center_name: str, rank: int): self.register_center = ray.get_actor(register_center_name) # set worker info for node affinity scheduling - ray.get(self.register_center.set_worker_info.remote(rank, ray.get_runtime_context().get_node_id())) + ray.get( + self.register_center.set_worker_info.remote( + rank, ray.get_runtime_context().get_node_id() + ) + ) @classmethod def env_keys(cls): @@ -230,7 +241,9 @@ def _setup_env_cuda_visible_devices(self): # Otherwise, we will set ROCR_VISIBLE_DEVICES to CUDA_VISIBLE_DEVICES # and remove ROCR_VISIBLE_DEVICES. if cuda_val: - raise ValueError("Please don't set ROCR_VISIBLE_DEVICES when HIP/CUDA_VISIBLE_DEVICES is set.") + raise ValueError( + "Please don't set ROCR_VISIBLE_DEVICES when HIP/CUDA_VISIBLE_DEVICES is set." 
+ ) cuda_val = os.environ.pop("ROCR_VISIBLE_DEVICES") os.environ["CUDA_VISIBLE_DEVICES"] = cuda_val @@ -249,7 +262,10 @@ def _configure_with_store(self, store: dict): """ This function should only be called inside by WorkerGroup """ - store_env_dict = {f"_{key.lower()}": store.get(f"_{key.lower()}", None) for key in type(self).env_keys()} + store_env_dict = { + f"_{key.lower()}": store.get(f"_{key.lower()}", None) + for key in type(self).env_keys() + } self.__dict__.update(store_env_dict) # this is hacky # print(f"__dict__: {self.__dict__}") for key in type(self).env_keys(): @@ -258,7 +274,9 @@ def _configure_with_store(self, store: dict): # print(f"set {key} to {val}") os.environ[key] = str(val) os.environ["REDIS_STORE_SERVER_HOST"] = ( - str(self._master_addr).replace("[", "").replace("]", "") if self._master_addr else "" + str(self._master_addr).replace("[", "").replace("]", "") + if self._master_addr + else "" ) def get_master_addr_port(self): @@ -269,7 +287,9 @@ def get_cuda_visible_devices(self): """Get the CUDA visible devices configuration.""" import os - visible_devices = os.environ.get(get_visible_devices_keyword().upper(), "not set") + visible_devices = os.environ.get( + get_visible_devices_keyword().upper(), "not set" + ) return visible_devices @property diff --git a/Agent0/executor_train/verl/verl/single_controller/base/worker_group.py b/Agent0/executor_train/verl/verl/single_controller/base/worker_group.py index cb86ab4..a83d5d9 100644 --- a/Agent0/executor_train/verl/verl/single_controller/base/worker_group.py +++ b/Agent0/executor_train/verl/verl/single_controller/base/worker_group.py @@ -21,7 +21,12 @@ import time from typing import Any, Callable -from .decorator import MAGIC_ATTR, Dispatch, get_predefined_dispatch_fn, get_predefined_execute_fn +from .decorator import ( + MAGIC_ATTR, + Dispatch, + get_predefined_dispatch_fn, + get_predefined_execute_fn, +) class ResourcePool: @@ -31,7 +36,9 @@ class ResourcePool: across all nodes in the pool. 
""" - def __init__(self, process_on_nodes=None, max_colocate_count: int = 10, n_gpus_per_node=8) -> None: + def __init__( + self, process_on_nodes=None, max_colocate_count: int = 10, n_gpus_per_node=8 + ) -> None: """Initialize the ResourcePool with node processes and GPU configuration. Args: @@ -63,13 +70,16 @@ def store(self): def local_world_size_list(self) -> list[int]: """Returns a flat list where each process has its local world size.""" nested_local_world_size_list = [ - [local_world_size for _ in range(local_world_size)] for local_world_size in self._store + [local_world_size for _ in range(local_world_size)] + for local_world_size in self._store ] return [item for row in nested_local_world_size_list for item in row] def local_rank_list(self) -> list[int]: """Returns a flat list of local ranks for all processes across all nodes.""" - nested_local_rank_list = [[i for i in range(local_world_size)] for local_world_size in self._store] + nested_local_rank_list = [ + [i for i in range(local_world_size)] for local_world_size in self._store + ] return [item for row in nested_local_rank_list for item in row] @@ -115,7 +125,9 @@ def check_workers_alive(workers: list, is_alive: Callable, gap_time: float = 1) while True: for worker in workers: if not is_alive(worker): - logging.warning(f"worker {worker} is not alive sending signal to main thread") + logging.warning( + f"worker {worker} is not alive sending signal to main thread" + ) signal.raise_signal(signal.SIGABRT) time.sleep(gap_time) @@ -149,7 +161,9 @@ def __init__(self, resource_pool: ResourcePool, **kwargs) -> None: def _is_worker_alive(self, worker): """Check if a worker is alive. Must be implemented by derived classes.""" - raise NotImplementedError("WorkerGroup._is_worker_alive called, should be implemented in derived class.") + raise NotImplementedError( + "WorkerGroup._is_worker_alive called, should be implemented in derived class." 
+ ) def _block_until_all_workers_alive(self) -> None: """Blocks until all workers in the group are alive.""" @@ -170,7 +184,8 @@ def start_worker_aliveness_check(self, every_n_seconds=1) -> None: self._block_until_all_workers_alive() self._checker_thread = threading.Thread( - target=check_workers_alive, args=(self._workers, self._is_worker_alive, every_n_seconds) + target=check_workers_alive, + args=(self._workers, self._is_worker_alive, every_n_seconds), ) self._checker_thread.start() @@ -193,7 +208,9 @@ def _bind_worker_method(self, user_defined_cls, func_generator): for method_name in dir(user_defined_cls): try: method = getattr(user_defined_cls, method_name) - assert callable(method), f"{method_name} in {user_defined_cls} is not callable" + assert callable( + method + ), f"{method_name} in {user_defined_cls} is not callable" except Exception: # if it is a property, it will fail because Class doesn't have instance property continue @@ -201,8 +218,12 @@ def _bind_worker_method(self, user_defined_cls, func_generator): if hasattr(method, MAGIC_ATTR): # this method is decorated by register attribute = getattr(method, MAGIC_ATTR) - assert isinstance(attribute, dict), f"attribute must be a dictionary. Got {type(attribute)}" - assert "dispatch_mode" in attribute, "attribute must contain dispatch_mode in its key" + assert isinstance( + attribute, dict + ), f"attribute must be a dictionary. 
Got {type(attribute)}" + assert ( + "dispatch_mode" in attribute + ), "attribute must contain dispatch_mode in its key" dispatch_mode = attribute["dispatch_mode"] execute_mode = attribute["execute_mode"] diff --git a/Agent0/executor_train/verl/verl/single_controller/ray/base.py b/Agent0/executor_train/verl/verl/single_controller/ray/base.py index bfcf87b..106f9a9 100644 --- a/Agent0/executor_train/verl/verl/single_controller/ray/base.py +++ b/Agent0/executor_train/verl/verl/single_controller/ray/base.py @@ -24,10 +24,18 @@ from ray.experimental.state.api import get_actor from ray.util import list_named_actors from ray.util.placement_group import PlacementGroup, placement_group -from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy, PlacementGroupSchedulingStrategy +from ray.util.scheduling_strategies import ( + NodeAffinitySchedulingStrategy, + PlacementGroupSchedulingStrategy, +) from verl.protocol import DataProto, _padding_size_key -from verl.single_controller.base import ClassWithInitArgs, ResourcePool, Worker, WorkerGroup +from verl.single_controller.base import ( + ClassWithInitArgs, + ResourcePool, + Worker, + WorkerGroup, +) from verl.single_controller.base.decorator import MAGIC_ATTR, Dispatch __all__ = ["Worker"] @@ -95,17 +103,23 @@ def __init__( super().__init__(process_on_nodes, max_colocate_count) self.use_gpu = use_gpu # print(f"in RayProcessDispatchConfiguration: name_prefix = {name_prefix}") - self.name_prefix = get_random_string(length=6) if name_prefix is None else name_prefix + self.name_prefix = ( + get_random_string(length=6) if name_prefix is None else name_prefix + ) self.pgs = None self.detached = detached self.accelerator_type = accelerator_type - def get_placement_groups(self, strategy="STRICT_PACK", name=None, device_name="cuda"): + def get_placement_groups( + self, strategy="STRICT_PACK", name=None, device_name="cuda" + ): if self.pgs is not None: return self.pgs pg_name_prefix = ( - name if name else 
f"{self.name_prefix}verl_group_{'_'.join([str(count) for count in self._store])}:" + name + if name + else f"{self.name_prefix}verl_group_{'_'.join([str(count) for count in self._store])}:" ) # print(f"pg_name_prefix = {pg_name_prefix}") if device_name == "npu": @@ -118,12 +132,20 @@ def get_placement_groups(self, strategy="STRICT_PACK", name=None, device_name="c bundle[device_name] = 1 if self.accelerator_type is not None: bundle[self.accelerator_type] = 1e-4 - pg_scheme = [[bundle.copy() for _ in range(process_count)] for process_count in self._store] + pg_scheme = [ + [bundle.copy() for _ in range(process_count)] + for process_count in self._store + ] lifetime = "detached" if self.detached else None pgs = [ - placement_group(bundles=bundles, strategy=strategy, name=pg_name_prefix + str(idx), lifetime=lifetime) + placement_group( + bundles=bundles, + strategy=strategy, + name=pg_name_prefix + str(idx), + lifetime=lifetime, + ) for idx, bundles in enumerate(pg_scheme) ] @@ -134,7 +156,9 @@ def get_placement_groups(self, strategy="STRICT_PACK", name=None, device_name="c def extract_pg_from_exist( - resource_pools: dict[str, RayResourcePool], src_role_names: list[str], resource_pool: RayResourcePool + resource_pools: dict[str, RayResourcePool], + src_role_names: list[str], + resource_pool: RayResourcePool, ) -> list: src_pgs = [ pg @@ -144,15 +168,19 @@ def extract_pg_from_exist( ] sorted_src_pgs = sorted(src_pgs, key=lambda pg: pg.bundle_count, reverse=True) - sorted_process_on_nodes = sorted([(val, idx) for idx, val in enumerate(resource_pool.store)], reverse=True) + sorted_process_on_nodes = sorted( + [(val, idx) for idx, val in enumerate(resource_pool.store)], reverse=True + ) unsorted_pgs: list[tuple[int, PlacementGroup]] = [] searching_idx = 0 for request_process, original_idx in sorted_process_on_nodes: - assert searching_idx < len(sorted_src_pgs), f"no enough nodes for request: searching {searching_idx} th node" - assert request_process <= 
sorted_src_pgs[searching_idx].bundle_count, ( - f"requesting {request_process} processes, bundle count cannot satisfy" - ) + assert searching_idx < len( + sorted_src_pgs + ), f"no enough nodes for request: searching {searching_idx} th node" + assert ( + request_process <= sorted_src_pgs[searching_idx].bundle_count + ), f"requesting {request_process} processes, bundle count cannot satisfy" unsorted_pgs.append((original_idx, sorted_src_pgs[searching_idx])) searching_idx += 1 @@ -161,9 +189,15 @@ def extract_pg_from_exist( def merge_resource_pool(rp1: RayResourcePool, rp2: RayResourcePool) -> RayResourcePool: assert rp1.use_gpu == rp2.use_gpu, "Both RayResourcePool must either use_gpu or not" - assert rp1.max_colocate_count == rp2.max_colocate_count, "Both RayResourcePool must has the same max_colocate_count" - assert rp1.n_gpus_per_node == rp2.n_gpus_per_node, "Both RayResourcePool must has the same n_gpus_per_node" - assert rp1.detached == rp2.detached, "Detached ResourcePool cannot be merged with non-detached ResourcePool" + assert ( + rp1.max_colocate_count == rp2.max_colocate_count + ), "Both RayResourcePool must has the same max_colocate_count" + assert ( + rp1.n_gpus_per_node == rp2.n_gpus_per_node + ), "Both RayResourcePool must has the same n_gpus_per_node" + assert ( + rp1.detached == rp2.detached + ), "Detached ResourcePool cannot be merged with non-detached ResourcePool" new_store = rp1.store + rp2.store @@ -228,12 +262,19 @@ def __call__( if sharing_with is not None: target_node_id = ray.get(sharing_with.get_node_id.remote()) visible_devices = ray.get(sharing_with.get_cuda_visible_devices.remote()) - options = {"scheduling_strategy": NodeAffinitySchedulingStrategy(node_id=target_node_id, soft=False)} - return self.cls.options(**options).remote(*self.args, cuda_visible_devices=visible_devices, **self.kwargs) + options = { + "scheduling_strategy": NodeAffinitySchedulingStrategy( + node_id=target_node_id, soft=False + ) + } + return 
self.cls.options(**options).remote( + *self.args, cuda_visible_devices=visible_devices, **self.kwargs + ) options = { "scheduling_strategy": PlacementGroupSchedulingStrategy( - placement_group=placement_group, placement_group_bundle_index=placement_group_bundle_idx + placement_group=placement_group, + placement_group_bundle_index=placement_group_bundle_idx, ) } options.update(self._options) @@ -288,7 +329,9 @@ def __init__( """ super().__init__(resource_pool=resource_pool, **kwargs) self.ray_cls_with_init = ray_cls_with_init - self.name_prefix = get_random_string(length=6) if name_prefix is None else name_prefix + self.name_prefix = ( + get_random_string(length=6) if name_prefix is None else name_prefix + ) self._ray_wait_register_center_timeout = ray_wait_register_center_timeout # Whether the WorkerGroup is a Colocate WorkerGroup created by FusedWorker. self.fused_worker_used = ray_cls_with_init.fused_worker_used @@ -298,18 +341,28 @@ def __init__( self.device_name = device_name self.profile_steps = kwargs.get("profile_steps", None) self.worker_nsight_options = kwargs.get("worker_nsight_options", None) - if self.worker_nsight_options is not None and self.worker_nsight_options["capture-range-end"] is None: - self.worker_nsight_options["capture-range-end"] = f"repeat-shutdown:{6 * len(self.profile_steps)}" + if ( + self.worker_nsight_options is not None + and self.worker_nsight_options["capture-range-end"] is None + ): + self.worker_nsight_options["capture-range-end"] = ( + f"repeat-shutdown:{6 * len(self.profile_steps)}" + ) if worker_names is not None and (not self.fused_worker_used): assert self._is_init_with_detached_workers self._worker_names = worker_names if self._is_init_with_detached_workers: - self._init_with_detached_workers(worker_names=worker_names, worker_handles=worker_handles) + self._init_with_detached_workers( + worker_names=worker_names, worker_handles=worker_handles + ) else: self._init_with_resource_pool( - resource_pool=resource_pool, 
ray_cls_with_init=ray_cls_with_init, bin_pack=bin_pack, detached=detached + resource_pool=resource_pool, + ray_cls_with_init=ray_cls_with_init, + bin_pack=bin_pack, + detached=detached, ) if ray_cls_with_init is not None: @@ -328,18 +381,28 @@ def _is_worker_alive(self, worker: ray.actor.ActorHandle): bool: True if the worker is alive, False otherwise """ worker_state_dict = get_actor(worker._actor_id.hex()) - return worker_state_dict.get("state", "undefined") == "ALIVE" if worker_state_dict is not None else False + return ( + worker_state_dict.get("state", "undefined") == "ALIVE" + if worker_state_dict is not None + else False + ) def _init_with_detached_workers(self, worker_names, worker_handles): # ray.get_actor holds a weak reference to the actor, which causes actors garbage collected unexpectedly # if we only hold spawn RayWorkerGroup. By passing actor handle explicitly, spawn RayWorkerGroup have # strong reference to these actors. # https://github.com/ray-project/ray/pull/45699 - workers = worker_handles if worker_handles else [ray.get_actor(name=name) for name in worker_names] + workers = ( + worker_handles + if worker_handles + else [ray.get_actor(name=name) for name in worker_names] + ) self._workers = workers self._world_size = len(worker_names) - def _init_with_resource_pool(self, resource_pool, ray_cls_with_init, bin_pack, detached): + def _init_with_resource_pool( + self, resource_pool, ray_cls_with_init, bin_pack, detached + ): """Initialize the worker group by creating new workers from a resource pool. 
Args: @@ -353,7 +416,9 @@ def _init_with_resource_pool(self, resource_pool, ray_cls_with_init, bin_pack, d strategy = "PACK" if bin_pack: strategy = "STRICT_PACK" - pgs = resource_pool.get_placement_groups(strategy=strategy, device_name=self.device_name) + pgs = resource_pool.get_placement_groups( + strategy=strategy, device_name=self.device_name + ) world_size = resource_pool.world_size self._world_size = world_size # cia.add_kwarg("_world_size", world_size) @@ -362,7 +427,9 @@ def _init_with_resource_pool(self, resource_pool, ray_cls_with_init, bin_pack, d rank = -1 local_world_size = resource_pool.store[0] for pg_idx, pg in enumerate(sort_placement_group_by_node_ip(pgs)): - assert local_world_size <= pg.bundle_count, f"when generating for {self.name_prefix}, for the " + assert ( + local_world_size <= pg.bundle_count + ), f"when generating for {self.name_prefix}, for the " for local_rank in range(local_world_size): rank += 1 @@ -382,8 +449,12 @@ def _init_with_resource_pool(self, resource_pool, ray_cls_with_init, bin_pack, d import re cia_name = type(ray_cls_with_init.cls).__name__ - match = re.search(r"ActorClass\(([^)]+)\)", cia_name) # ray.remote(Obj) -> "ActorClass(Obj)" - cia_name = match.group(1) if match else cia_name # "ActorClass(Obj)" -> "Obj" + match = re.search( + r"ActorClass\(([^)]+)\)", cia_name + ) # ray.remote(Obj) -> "ActorClass(Obj)" + cia_name = ( + match.group(1) if match else cia_name + ) # "ActorClass(Obj)" -> "Obj" name = f"{self.name_prefix}{cia_name}_{pg_idx}:{local_rank}" # e.g. 
Worker_2:5 if self.profile_steps and self.device_name == "cuda": @@ -397,7 +468,9 @@ def _init_with_resource_pool(self, resource_pool, ray_cls_with_init, bin_pack, d } ) else: - ray_cls_with_init.update_options({"runtime_env": {"env_vars": env_vars}, "name": name}) + ray_cls_with_init.update_options( + {"runtime_env": {"env_vars": env_vars}, "name": name} + ) if detached: ray_cls_with_init.update_options({"lifetime": "detached"}) @@ -418,7 +491,10 @@ def _init_with_resource_pool(self, resource_pool, ray_cls_with_init, bin_pack, d actor_name = f"{self.name_prefix}_register_center" start_time = time.time() - while time.time() - start_time < self._ray_wait_register_center_timeout: + while ( + time.time() - start_time + < self._ray_wait_register_center_timeout + ): if actor_name in list_named_actors(): register_center_actor = ray.get_actor(actor_name) break @@ -445,8 +521,13 @@ def _init_with_resource_pool(self, resource_pool, ray_cls_with_init, bin_pack, d "`trainer.ray_wait_register_center_timeout`." 
) - rank_zero_info = ray.get(register_center_actor.get_rank_zero_info.remote()) - self._master_addr, self._master_port = rank_zero_info["MASTER_ADDR"], rank_zero_info["MASTER_PORT"] + rank_zero_info = ray.get( + register_center_actor.get_rank_zero_info.remote() + ) + self._master_addr, self._master_port = ( + rank_zero_info["MASTER_ADDR"], + rank_zero_info["MASTER_PORT"], + ) # print(f"rank_zero_info: {rank_zero_info}") # print(f"master_addr: {self._master_addr}, master_port: {self._master_port}") @@ -530,7 +611,9 @@ def spawn_fused(self, prefix_set): wg_dict = dict() for key in prefix_set: new_wg = deepcopy(self) - new_wg._bind_worker_method(self.ray_cls_with_init.cls.raw_cls_dict[key], func_generator) + new_wg._bind_worker_method( + self.ray_cls_with_init.cls.raw_cls_dict[key], func_generator + ) new_wg.sub_cls_name = key wg_dict[key] = new_wg return wg_dict @@ -545,7 +628,9 @@ def fuse(self, prefix_set): self.wg_dict = self.spawn(prefix_set) for role_name, role_wg in self.wg_dict.items(): setattr(self, role_name, role_wg) - self.method_names = self._bind_worker_method(self.ray_cls_with_init.cls, func_generator) + self.method_names = self._bind_worker_method( + self.ray_cls_with_init.cls, func_generator + ) def _execute_remote_single_worker(self, worker, method_name: str, *args, **kwargs): """Execute a method on a single worker remotely. 
@@ -561,7 +646,9 @@ def _execute_remote_single_worker(self, worker, method_name: str, *args, **kwarg """ if self.fused_worker_used and method_name not in self.method_names: remote_call = getattr(worker, self.fused_worker_execute_fn_name) - return remote_call.remote(f"{self.sub_cls_name}_fwmn_{method_name}", *args, **kwargs) + return remote_call.remote( + f"{self.sub_cls_name}_fwmn_{method_name}", *args, **kwargs + ) # fused worker not used remote_call = getattr(worker, method_name) return remote_call.remote(*args, **kwargs) @@ -590,7 +677,9 @@ def execute_rank_zero_async(self, method_name: str, *args, **kwargs): Returns: Remote object reference to the method execution """ - return self._execute_remote_single_worker(self._workers[0], method_name, *args, **kwargs) + return self._execute_remote_single_worker( + self._workers[0], method_name, *args, **kwargs + ) def execute_rank_zero(self, method_name: str, *args, **kwargs): """Alias for execute_rank_zero_async. @@ -647,19 +736,28 @@ def execute_all_async(self, method_name: str, *args, **kwargs): # element in these lists to the corresponding worker # print(f"execute_all_async: method {method_name}({args}, {kwargs})") length = len(self._workers) - if all(isinstance(arg, list) for arg in args) and all(isinstance(kwarg, list) for kwarg in kwargs.values()): - if all(len(arg) == length for arg in args) and all(len(kwarg) == length for kwarg in kwargs.values()): + if all(isinstance(arg, list) for arg in args) and all( + isinstance(kwarg, list) for kwarg in kwargs.values() + ): + if all(len(arg) == length for arg in args) and all( + len(kwarg) == length for kwarg in kwargs.values() + ): # print(f"splitting args and kwargs into {length} shards") result = [] for i in range(length): sliced_args = tuple(arg[i] for arg in args) sliced_kwargs = {k: v[i] for k, v in kwargs.items()} result.append( - self._execute_remote_single_worker(self._workers[i], method_name, *sliced_args, **sliced_kwargs) + self._execute_remote_single_worker( + 
self._workers[i], method_name, *sliced_args, **sliced_kwargs + ) ) return result - return [self._execute_remote_single_worker(worker, method_name, *args, **kwargs) for worker in self._workers] + return [ + self._execute_remote_single_worker(worker, method_name, *args, **kwargs) + for worker in self._workers + ] @property def master_address(self): @@ -694,7 +792,9 @@ def _bind_workers_method_to_parent(cls, key, user_defined_cls): for method_name in dir(user_defined_cls): try: method = getattr(user_defined_cls, method_name) - assert callable(method), f"{method_name} in {user_defined_cls} is not callable" + assert callable( + method + ), f"{method_name} in {user_defined_cls} is not callable" except Exception: # if it is a property, it will fail because Class doesn't have instance property continue @@ -710,7 +810,9 @@ async def async_func(self, *args, **kwargs): # dispatch to the actual worker return await getattr(self.worker_dict[key], name)(*args, **kwargs) - wrapper = async_func if inspect.iscoroutinefunction(method) else func # noqa: B023 + wrapper = ( + async_func if inspect.iscoroutinefunction(method) else func + ) # noqa: B023 return wrapper @@ -720,10 +822,13 @@ async def async_func(self, *args, **kwargs): setattr(func, MAGIC_ATTR, attrs) try: # bind direct rollout method to class without prefix - if attrs["dispatch_mode"] == Dispatch.DIRECT_ROLLOUT_METHOD and "rollout" in key: - assert not hasattr(cls, method_name), ( - f"conflict direct rollout method {method_name} with role {key}" - ) + if ( + attrs["dispatch_mode"] == Dispatch.DIRECT_ROLLOUT_METHOD + and "rollout" in key + ): + assert not hasattr( + cls, method_name + ), f"conflict direct rollout method {method_name} with role {key}" setattr(cls, method_name, func) print(f"bind role {key} method {method_name} to class {cls}") else: @@ -763,7 +868,9 @@ def create_colocated_worker_cls(class_dict: dict[str, RayClassWithInitArgs]): worker_cls = _determine_fsdp_megatron_base_class( 
[cls.cls.__ray_actor_class__.__mro__ for cls in class_dict.values()] ) - assert issubclass(worker_cls, Worker), f"worker_cls {worker_cls} should be a subclass of Worker" + assert issubclass( + worker_cls, Worker + ), f"worker_cls {worker_cls} should be a subclass of Worker" print(f"colocated worker base class {worker_cls}") for key, cls in class_dict.items(): @@ -784,7 +891,8 @@ def __init__(self): # when DISABLE_WORKER_INIT == 1 it will return immediately with patch.dict(os.environ, {"DISABLE_WORKER_INIT": "1"}): self.worker_dict[key] = user_defined_cls( - *init_args_dict[key].get("args", ()), **init_args_dict[key].get("kwargs", {}) + *init_args_dict[key].get("args", ()), + **init_args_dict[key].get("kwargs", {}), ) # now monkey-patch the methods from inner class to WorkerDict @@ -818,7 +926,9 @@ def create_colocated_worker_raw_cls(class_dict: dict[str, RayClassWithInitArgs]) The same as `FusedWorker.fused_worker_dict`, enables underlying class to access other underlying classes. """ - raw_cls_dict = {cls_name: _unwrap_ray_remote(cia.cls) for cls_name, cia in class_dict.items()} + raw_cls_dict = { + cls_name: _unwrap_ray_remote(cia.cls) for cls_name, cia in class_dict.items() + } init_args_dict = {cls_name: cia.args for cls_name, cia in class_dict.items()} init_kwargs_dict = {cls_name: cia.kwargs for cls_name, cia in class_dict.items()} cls_names = list(class_dict.keys()) @@ -842,8 +952,12 @@ def __init__(self, *args, **kwargs): strict=True, ): with patch.dict(os.environ, {"DISABLE_WORKER_INIT": "1"}): - udc._get_ray_actor_cls_name = lambda x, name_renamed=class_name_renamed: name_renamed - udc._get_ray_method_prefix = lambda x, name_prefixed=cls_name: f"{name_prefixed}_" + udc._get_ray_actor_cls_name = ( + lambda x, name_renamed=class_name_renamed: name_renamed + ) + udc._get_ray_method_prefix = ( + lambda x, name_prefixed=cls_name: f"{name_prefixed}_" + ) # cls_name = "actor", "critic", udc = ActorWorker, CriticWorker self.fused_worker_dict[cls_name] = 
udc(*ud_args, **ud_kwargs) setattr(self, cls_name, self.fused_worker_dict[cls_name]) @@ -859,9 +973,9 @@ def _fuw_execute(self, method_name: str, *args, **kwargs): cls_name = names[0] method_name = names[1] - assert cls_name in self.fused_worker_dict, ( - f"calling {cls_name}'s {method_name}, but {cls_name} not in fused_worker_dict" - ) + assert ( + cls_name in self.fused_worker_dict + ), f"calling {cls_name}'s {method_name}, but {cls_name} not in fused_worker_dict" udc_method = getattr(self.fused_worker_dict[cls_name], method_name) return udc_method(*args, **kwargs) diff --git a/Agent0/executor_train/verl/verl/single_controller/ray/megatron.py b/Agent0/executor_train/verl/verl/single_controller/ray/megatron.py index b46fe44..69ab9e3 100644 --- a/Agent0/executor_train/verl/verl/single_controller/ray/megatron.py +++ b/Agent0/executor_train/verl/verl/single_controller/ray/megatron.py @@ -29,7 +29,12 @@ class NVMegatronRayWorkerGroup(RayWorkerGroup, MegatronWorkerGroup): so that the dispatcher can use it to dispatch data. """ - def __init__(self, resource_pool: RayResourcePool, ray_cls_with_init: RayClassWithInitArgs, **kwargs): + def __init__( + self, + resource_pool: RayResourcePool, + ray_cls_with_init: RayClassWithInitArgs, + **kwargs, + ): """ Initialize the NVMegatronRayWorkerGroup. 
@@ -38,8 +43,12 @@ def __init__(self, resource_pool: RayResourcePool, ray_cls_with_init: RayClassWi ray_cls_with_init (RayClassWithInitArgs): The Ray class with initialization arguments **kwargs: Additional keyword arguments to pass to the parent class """ - super().__init__(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init, **kwargs) - self._megatron_rank_info: DistRankInfo = self.execute_all_sync(method_name="get_megatron_rank_info") + super().__init__( + resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init, **kwargs + ) + self._megatron_rank_info: DistRankInfo = self.execute_all_sync( + method_name="get_megatron_rank_info" + ) self._megatron_global_info: DistGlobalInfo = ray.get( self.execute_rank_zero_async(method_name="get_megatron_global_info") ) @@ -65,7 +74,9 @@ def __init__( **kwargs, ) self.init_megatron(default_megatron_kwargs=default_megatron_kwargs) - self._megatron_rank_info: DistRankInfo = self.execute_all_sync(method_name="get_megatron_rank_info") + self._megatron_rank_info: DistRankInfo = self.execute_all_sync( + method_name="get_megatron_rank_info" + ) self._megatron_global_info: DistGlobalInfo = ray.get( self.execute_rank_zero_async(method_name="get_megatron_global_info") ) @@ -74,4 +85,7 @@ def init_megatron(self, default_megatron_kwargs: Optional[dict] = None): # after super, we will call init of each worker if not self._is_init_with_detached_workers: # only init_megatron if the WorkerGroup is created from scratch - self.execute_all_sync(method_name="init_megatron", default_megatron_kwargs=default_megatron_kwargs) + self.execute_all_sync( + method_name="init_megatron", + default_megatron_kwargs=default_megatron_kwargs, + ) diff --git a/Agent0/executor_train/verl/verl/third_party/sglang/parallel_state.py b/Agent0/executor_train/verl/verl/third_party/sglang/parallel_state.py index cdec743..e8a5842 100644 --- a/Agent0/executor_train/verl/verl/third_party/sglang/parallel_state.py +++ 
b/Agent0/executor_train/verl/verl/third_party/sglang/parallel_state.py @@ -57,7 +57,9 @@ def initialize_parallel_state( # Use the world_size set by TORCHRUN world_size = int(os.getenv("WORLD_SIZE", "-1")) assert world_size != -1, "The world_size is set to -1, not initialized by TORCHRUN" - init_distributed_environment(world_size, rank, distributed_init_method, local_rank, backend) + init_distributed_environment( + world_size, rank, distributed_init_method, local_rank, backend + ) if torch.distributed.get_world_size() > 1: # NOTE: build a separate inference group with infer tp & micro dp initialize_model_parallel_for_sglang( @@ -65,7 +67,9 @@ def initialize_parallel_state( num_tensor_model_parallel_groups_per_train_tp=num_tp_per_train_tp, ) else: - initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, backend) + initialize_model_parallel( + tensor_model_parallel_size, pipeline_model_parallel_size, backend + ) # NOTE(linjunrong): After init SGLang rollout using class EngineFragment, user should always remember to call @@ -86,7 +90,9 @@ def ensure_model_parallel_initialized( # get the backend of _DEVICE_WORLD_GROUP backend = backend or torch.distributed.get_backend(get_world_group().device_group) if not model_parallel_is_initialized(): - initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, backend) + initialize_model_parallel( + tensor_model_parallel_size, pipeline_model_parallel_size, backend + ) return assert get_tensor_model_parallel_world_size() == tensor_model_parallel_size, ( @@ -140,7 +146,9 @@ def initialize_model_parallel_for_sglang( assert _TP is None, "tensor model parallel group is already initialized" group_ranks = [] for i in range(num_tensor_model_parallel_groups): - ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size) + ranks = range( + i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size + ) group_ranks.append(ranks) _TP = 
init_model_parallel_group( group_ranks=group_ranks, @@ -158,15 +166,22 @@ def initialize_model_parallel_for_sglang( # Build the inference tp groups # train_tp = train_tensor_parallel_size - train_tp = num_tensor_model_parallel_groups_per_train_tp * tensor_model_parallel_size + train_tp = ( + num_tensor_model_parallel_groups_per_train_tp * tensor_model_parallel_size + ) # num_tensor_model_parallel_groups_per_train_tp = train_tp // tensor_model_parallel_size assert _TP is None, "tensor model parallel group is already initialized" group_ranks = [] - for i in range(num_tensor_model_parallel_groups // num_tensor_model_parallel_groups_per_train_tp): + for i in range( + num_tensor_model_parallel_groups + // num_tensor_model_parallel_groups_per_train_tp + ): start = train_tp * i end = train_tp * (i + 1) for j in range(num_tensor_model_parallel_groups_per_train_tp): - ranks = list(range(start, end, num_tensor_model_parallel_groups_per_train_tp)) + ranks = list( + range(start, end, num_tensor_model_parallel_groups_per_train_tp) + ) for i in range(len(ranks)): ranks[i] += j group_ranks.append(ranks) @@ -197,7 +212,9 @@ def initialize_model_parallel_for_sglang( ranks = list(range(i, world_size, num_pipeline_model_parallel_groups)) group_ranks.append(ranks) # pipeline parallel does not need custom allreduce - _PP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, use_custom_allreduce=False) + _PP = init_model_parallel_group( + group_ranks, get_world_group().local_rank, backend, use_custom_allreduce=False + ) ps._PP = _PP # for verl @@ -234,7 +251,9 @@ def initialize_model_parallel( # Get world size and rank. Ensure some consistencies. 
assert torch.distributed.is_initialized() world_size: int = torch.distributed.get_world_size() - backend = backend or torch.distributed.get_backend(ps.get_world_group().device_group) + backend = backend or torch.distributed.get_backend( + ps.get_world_group().device_group + ) # NOTE(sgm) we don't assert world_size == tp * pp # DP is not managed by vllm but by the VeRL WorkerGroup @@ -251,7 +270,9 @@ def initialize_model_parallel( assert _TP is None, "tensor model parallel group is already initialized" group_ranks = [] for i in range(num_tensor_model_parallel_groups): - ranks = list(range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)) + ranks = list( + range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size) + ) group_ranks.append(ranks) # message queue broadcaster is only used in tensor model parallel group @@ -280,7 +301,12 @@ def initialize_model_parallel( if ps._TP is not None: _PP = ps._TP else: - _PP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, use_custom_allreduce=False) + _PP = init_model_parallel_group( + group_ranks, + get_world_group().local_rank, + backend, + use_custom_allreduce=False, + ) ps._PP = _PP diff --git a/Agent0/executor_train/verl/verl/tools/base_tool.py b/Agent0/executor_train/verl/verl/tools/base_tool.py index 9a1189d..e9a85d2 100644 --- a/Agent0/executor_train/verl/verl/tools/base_tool.py +++ b/Agent0/executor_train/verl/verl/tools/base_tool.py @@ -38,7 +38,12 @@ def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema): self.tool_schema = tool_schema or self.get_openai_tool_schema() assert self.tool_schema is not None, "Tool schema is not set!" 
self.name = self.tool_schema.function.name - print(json.dumps(self.tool_schema.model_dump(exclude_unset=True, exclude_none=True), indent=2)) + print( + json.dumps( + self.tool_schema.model_dump(exclude_unset=True, exclude_none=True), + indent=2, + ) + ) def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema: return self.tool_schema @@ -58,7 +63,9 @@ async def create(self, instance_id: Optional[str] = None, **kwargs) -> str: return instance_id @rollout_trace_op - async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]: + async def execute( + self, instance_id: str, parameters: dict[str, Any], **kwargs + ) -> tuple[str, float, dict]: """Execute the tool. Args: diff --git a/Agent0/executor_train/verl/verl/tools/geo3k_tool.py b/Agent0/executor_train/verl/verl/tools/geo3k_tool.py index 6ffd6fb..d3a4f33 100644 --- a/Agent0/executor_train/verl/verl/tools/geo3k_tool.py +++ b/Agent0/executor_train/verl/verl/tools/geo3k_tool.py @@ -64,7 +64,12 @@ def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema): def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema: return self.tool_schema - async def create(self, instance_id: Optional[str] = None, ground_truth: Optional[str] = None, **kwargs) -> str: + async def create( + self, + instance_id: Optional[str] = None, + ground_truth: Optional[str] = None, + **kwargs, + ) -> str: if instance_id is None: instance_id = str(uuid4()) self._instance_dict[instance_id] = { @@ -75,14 +80,18 @@ async def create(self, instance_id: Optional[str] = None, ground_truth: Optional return instance_id, None @rollout_trace_op - async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]: + async def execute( + self, instance_id: str, parameters: dict[str, Any], **kwargs + ) -> tuple[str, float, dict]: answer = parameters.get("answer", "") if not isinstance(answer, str): answer = str(answer) 
self._instance_dict[instance_id]["response"] = answer reward = await self.calc_reward(instance_id) # penalty for non improved answer submission - tool_reward = 0.0 if reward > self._instance_dict[instance_id]["reward"] else -0.05 + tool_reward = ( + 0.0 if reward > self._instance_dict[instance_id]["reward"] else -0.05 + ) # update the reward self._instance_dict[instance_id]["reward"] = reward return f"Current parsed {answer=} {reward=}", tool_reward, {} diff --git a/Agent0/executor_train/verl/verl/tools/gsm8k_tool.py b/Agent0/executor_train/verl/verl/tools/gsm8k_tool.py index f6d8913..bc0eea6 100644 --- a/Agent0/executor_train/verl/verl/tools/gsm8k_tool.py +++ b/Agent0/executor_train/verl/verl/tools/gsm8k_tool.py @@ -64,7 +64,12 @@ def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema): def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema: return self.tool_schema - async def create(self, instance_id: Optional[str] = None, ground_truth: Optional[str] = None, **kwargs) -> str: + async def create( + self, + instance_id: Optional[str] = None, + ground_truth: Optional[str] = None, + **kwargs, + ) -> str: if instance_id is None: instance_id = str(uuid4()) self._instance_dict[instance_id] = { @@ -75,7 +80,9 @@ async def create(self, instance_id: Optional[str] = None, ground_truth: Optional return instance_id @rollout_trace_op - async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]: + async def execute( + self, instance_id: str, parameters: dict[str, Any], **kwargs + ) -> tuple[str, float, dict]: answer = parameters.get("answer", "") if not isinstance(answer, str): answer = str(answer) @@ -87,7 +94,9 @@ async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) reward = await self.calc_reward(instance_id) # penalty for non improved answer submission - tool_reward = 0.0 if reward > self._instance_dict[instance_id]["reward"] else -0.05 + tool_reward = ( + 0.0 if reward > 
self._instance_dict[instance_id]["reward"] else -0.05 + ) # update the reward self._instance_dict[instance_id]["reward"] = reward diff --git a/Agent0/executor_train/verl/verl/tools/mcp_base_tool.py b/Agent0/executor_train/verl/verl/tools/mcp_base_tool.py index dacd18e..981bf2d 100644 --- a/Agent0/executor_train/verl/verl/tools/mcp_base_tool.py +++ b/Agent0/executor_train/verl/verl/tools/mcp_base_tool.py @@ -63,7 +63,9 @@ async def create(self, instance_id: Optional[str] = None, **kwargs) -> str: async def _call_tool(self, instance_id, parameters) -> tuple[str, dict]: err_msg = "" try: - call_tool_result = await ClientManager.call_tool(self.name, parameters, self.timeout) + call_tool_result = await ClientManager.call_tool( + self.name, parameters, self.timeout + ) except ClientError as e: err_msg = f"\n Tool call failed: {e}" except ConnectionError as e: @@ -71,16 +73,22 @@ async def _call_tool(self, instance_id, parameters) -> tuple[str, dict]: except Exception as e: err_msg = f"\n An unexpected error occurred: {e}" - logger.debug(f"Tool result for instance {instance_id} with tool {self.name}: {call_tool_result.content}") + logger.debug( + f"Tool result for instance {instance_id} with tool {self.name}: {call_tool_result.content}" + ) result, metadata = self._parse_tool_result(call_tool_result.content) metadata["api_request_error"] += err_msg return result, metadata @rollout_trace_op - async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]: + async def execute( + self, instance_id: str, parameters: dict[str, Any], **kwargs + ) -> tuple[str, float, dict]: if self.name == "" or self.name is None or parameters is None: error_msg = "Error: 'parameters' is missing or empty." 
- logger.error(f"[MCPTool] {error_msg} Received tool name: {self.name}, parameters: {parameters}") + logger.error( + f"[MCPTool] {error_msg} Received tool name: {self.name}, parameters: {parameters}" + ) return json.dumps({"result": error_msg}), 0.0, {} try: @@ -112,5 +120,7 @@ async def release(self, instance_id: str, **kwargs) -> None: del self._instance_dict[instance_id] def _parse_tool_result(self, content: list) -> tuple[str, dict]: - tools_content = [part.text for part in filter(lambda x: x.type == "text", content)] + tools_content = [ + part.text for part in filter(lambda x: x.type == "text", content) + ] return " ".join(tools_content), {} diff --git a/Agent0/executor_train/verl/verl/tools/sandbox_fusion_tools.py b/Agent0/executor_train/verl/verl/tools/sandbox_fusion_tools.py index c3a2748..5819e85 100644 --- a/Agent0/executor_train/verl/verl/tools/sandbox_fusion_tools.py +++ b/Agent0/executor_train/verl/verl/tools/sandbox_fusion_tools.py @@ -63,12 +63,16 @@ def get_current_count(self): class ExecutionWorker: def __init__(self, enable_global_rate_limit=True, rate_limit=10): - self.rate_limit_worker = self._init_rate_limit(rate_limit) if enable_global_rate_limit else None + self.rate_limit_worker = ( + self._init_rate_limit(rate_limit) if enable_global_rate_limit else None + ) def _init_rate_limit(self, rate_limit): # TODO validation for rate_limit # A Singleton Rate Limitor - return TokenBucketWorker.options(name="rate-limiter", get_if_exists=True).remote(rate_limit) + return TokenBucketWorker.options( + name="rate-limiter", get_if_exists=True + ).remote(rate_limit) def ping(self): return True @@ -85,13 +89,18 @@ def execute(self, fn: Callable[..., T], *fn_args, **fn_kwargs) -> T: def init_execution_pool( - num_workers: int, enable_global_rate_limit=True, rate_limit=10, mode: PoolMode = PoolMode.ThreadMode + num_workers: int, + enable_global_rate_limit=True, + rate_limit=10, + mode: PoolMode = PoolMode.ThreadMode, ): if mode == PoolMode.ThreadMode: return ( 
ray.remote(ExecutionWorker) .options(max_concurrency=num_workers) - .remote(enable_global_rate_limit=enable_global_rate_limit, rate_limit=rate_limit) + .remote( + enable_global_rate_limit=enable_global_rate_limit, rate_limit=rate_limit + ) ) else: raise NotImplementedError("Process mode is not implemented yet") @@ -152,7 +161,12 @@ def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema): def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema: return self.tool_schema - async def create(self, instance_id: Optional[str] = None, ground_truth: Optional[str] = None, **kwargs) -> str: + async def create( + self, + instance_id: Optional[str] = None, + ground_truth: Optional[str] = None, + **kwargs, + ) -> str: if instance_id is None: instance_id = str(uuid4()) self._instance_dict[instance_id] = { @@ -163,25 +177,38 @@ async def create(self, instance_id: Optional[str] = None, ground_truth: Optional return instance_id @rollout_trace_op - async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]: + async def execute( + self, instance_id: str, parameters: dict[str, Any], **kwargs + ) -> tuple[str, float, dict]: code = parameters.get("code", "") timeout = parameters.get("timeout", self.default_timeout) language = parameters.get("language", self.default_language) if not isinstance(code, str): code = str(code) - result = await self.execution_pool.execute.remote(self.execute_code, instance_id, code, timeout, language) + result = await self.execution_pool.execute.remote( + self.execute_code, instance_id, code, timeout, language + ) # sandbox has no score or metrics, use Nones return result, None, None def execute_code(self, instance_id, code, timeout=30, language="python"): result_status, metadata = _process_single_case( - 0, None, None, self.sandbox_fusion_url, code, timeout, self.memory_limit_mb, language + 0, + None, + None, + self.sandbox_fusion_url, + code, + timeout, + self.memory_limit_mb, + language, ) # 
we should always expect this since we don't have correct answer if metadata["run_status"] == "Finished": actual_output = metadata["stdout"] + metadata["stderr"] - logger.debug(f"actual_output from sandbox fusion: {actual_output},{instance_id}") + logger.debug( + f"actual_output from sandbox fusion: {actual_output},{instance_id}" + ) return actual_output else: return "no stdout here" diff --git a/Agent0/executor_train/verl/verl/tools/schemas.py b/Agent0/executor_train/verl/verl/tools/schemas.py index c0c65a3..6e08bda 100644 --- a/Agent0/executor_train/verl/verl/tools/schemas.py +++ b/Agent0/executor_train/verl/verl/tools/schemas.py @@ -78,7 +78,10 @@ def from_openai_function_parsed_schema( arguments = {} has_decode_error = True - return OpenAIFunctionCallSchema(name=parsed_schema.name, arguments=arguments), has_decode_error + return ( + OpenAIFunctionCallSchema(name=parsed_schema.name, arguments=arguments), + has_decode_error, + ) class OpenAIFunctionToolCall(BaseModel): diff --git a/Agent0/executor_train/verl/verl/tools/search_tool.py b/Agent0/executor_train/verl/verl/tools/search_tool.py index 3cc6cda..bb20716 100644 --- a/Agent0/executor_train/verl/verl/tools/search_tool.py +++ b/Agent0/executor_train/verl/verl/tools/search_tool.py @@ -75,11 +75,15 @@ class SearchExecutionWorker: """Worker for executing search operations with optional rate limiting.""" def __init__(self, enable_global_rate_limit=True, rate_limit=10): - self.rate_limit_worker = self._init_rate_limit(rate_limit) if enable_global_rate_limit else None + self.rate_limit_worker = ( + self._init_rate_limit(rate_limit) if enable_global_rate_limit else None + ) def _init_rate_limit(self, rate_limit): """Initialize singleton rate limiter.""" - return TokenBucketWorker.options(name="rate-limiter", get_if_exists=True).remote(rate_limit) + return TokenBucketWorker.options( + name="rate-limiter", get_if_exists=True + ).remote(rate_limit) def ping(self): """Health check method.""" @@ -101,14 +105,19 @@ def 
execute(self, fn: Callable[..., T], *fn_args, **fn_kwargs) -> T: def init_search_execution_pool( - num_workers: int, enable_global_rate_limit=True, rate_limit=10, mode: PoolMode = PoolMode.ThreadMode + num_workers: int, + enable_global_rate_limit=True, + rate_limit=10, + mode: PoolMode = PoolMode.ThreadMode, ): """Initialize search execution pool.""" if mode == PoolMode.ThreadMode: return ( ray.remote(SearchExecutionWorker) .options(max_concurrency=num_workers) - .remote(enable_global_rate_limit=enable_global_rate_limit, rate_limit=rate_limit) + .remote( + enable_global_rate_limit=enable_global_rate_limit, rate_limit=rate_limit + ) ) else: raise NotImplementedError("Process mode is not implemented yet") @@ -174,7 +183,9 @@ def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema): # Retrieval service configuration self.retrieval_service_url = config.get("retrieval_service_url") - assert self.retrieval_service_url, "Configuration must include 'retrieval_service_url'" + assert ( + self.retrieval_service_url + ), "Configuration must include 'retrieval_service_url'" self.topk = config.get("topk", 3) if self.retrieval_service_url == "": raise ValueError("retrieval_service_url is not set") @@ -202,7 +213,14 @@ async def create(self, instance_id: Optional[str] = None, **kwargs) -> str: } return instance_id - def execute_search(self, instance_id: str, query_list: list, retrieval_service_url: str, topk: int, timeout: int): + def execute_search( + self, + instance_id: str, + query_list: list, + retrieval_service_url: str, + topk: int, + timeout: int, + ): """Execute search operation using retrieval service. 
Args: @@ -226,7 +244,9 @@ def execute_search(self, instance_id: str, query_list: list, retrieval_service_u return result_text, metadata @rollout_trace_op - async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]: + async def execute( + self, instance_id: str, parameters: dict[str, Any], **kwargs + ) -> tuple[str, float, dict]: """Execute the search tool. Args: @@ -242,14 +262,21 @@ async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) query_list_from_params = parameters.get("query_list") if not query_list_from_params or not isinstance(query_list_from_params, list): - error_msg = "Error: 'query_list' is missing, empty, or not a list in parameters." + error_msg = ( + "Error: 'query_list' is missing, empty, or not a list in parameters." + ) logger.error(f"[SearchTool] {error_msg} Received parameters: {parameters}") return json.dumps({"result": error_msg}), 0.0, {} # Execute search using Ray execution pool try: result_text, metadata = await self.execution_pool.execute.remote( - self.execute_search, instance_id, query_list_from_params, self.retrieval_service_url, self.topk, timeout + self.execute_search, + instance_id, + query_list_from_params, + self.retrieval_service_url, + self.topk, + timeout, ) # Store results in instance dictionary diff --git a/Agent0/executor_train/verl/verl/tools/utils/mcp_clients/McpClientManager.py b/Agent0/executor_train/verl/verl/tools/utils/mcp_clients/McpClientManager.py index ee5fe31..bf747e4 100644 --- a/Agent0/executor_train/verl/verl/tools/utils/mcp_clients/McpClientManager.py +++ b/Agent0/executor_train/verl/verl/tools/utils/mcp_clients/McpClientManager.py @@ -42,7 +42,10 @@ async def initialize(self, config_path, rate_limit: float = 10.0): for server_name in servers.keys(): server = servers[server_name] if "auth_token" in server: - transport = SSETransport(url=server["url"], headers={"Authorization": f"Bearer {server['auth_token']}"}) + transport = 
SSETransport( + url=server["url"], + headers={"Authorization": f"Bearer {server['auth_token']}"}, + ) client = Client(transport) self.clients.append(client) else: diff --git a/Agent0/executor_train/verl/verl/tools/utils/search_r1_like_utils.py b/Agent0/executor_train/verl/verl/tools/utils/search_r1_like_utils.py index 23669e4..fc147db 100644 --- a/Agent0/executor_train/verl/verl/tools/utils/search_r1_like_utils.py +++ b/Agent0/executor_train/verl/verl/tools/utils/search_r1_like_utils.py @@ -92,7 +92,9 @@ def call_search_api( response.raise_for_status() # If successful (status code 2xx) - logger.info(f"{log_prefix}Search API call successful on attempt {attempt + 1}") + logger.info( + f"{log_prefix}Search API call successful on attempt {attempt + 1}" + ) return response.json(), None except requests.exceptions.ConnectionError as e: @@ -124,7 +126,11 @@ def call_search_api( # If loop finishes without returning success, return the last recorded error logger.error(f"{log_prefix}Search API call failed. 
Last error: {last_error}") - return None, last_error.replace(log_prefix, "API Call Failed: ") if last_error else "API Call Failed after retries" + return None, ( + last_error.replace(log_prefix, "API Call Failed: ") + if last_error + else "API Call Failed after retries" + ) def _passages2string(retrieval_result): @@ -198,7 +204,9 @@ def perform_single_search_batch( "formatted_result": None, } - result_text = json.dumps({"result": "Search request failed or timed out after retries."}) + result_text = json.dumps( + {"result": "Search request failed or timed out after retries."} + ) if error_msg: metadata["status"] = "api_error" @@ -217,14 +225,18 @@ def perform_single_search_batch( for retrieval in raw_results: formatted = _passages2string(retrieval) pretty_results.append(formatted) - total_results += len(retrieval) if isinstance(retrieval, list) else 1 + total_results += ( + len(retrieval) if isinstance(retrieval, list) else 1 + ) final_result = "\n---\n".join(pretty_results) result_text = json.dumps({"result": final_result}) metadata["status"] = "success" metadata["total_results"] = total_results metadata["formatted_result"] = final_result - logger.info(f"Batch search: Successful, got {total_results} total results") + logger.info( + f"Batch search: Successful, got {total_results} total results" + ) else: result_text = json.dumps({"result": "No search results found."}) metadata["status"] = "no_results" @@ -237,7 +249,9 @@ def perform_single_search_batch( logger.error(f"Batch search: {error_msg}") else: metadata["status"] = "unknown_api_state" - result_text = json.dumps({"result": "Unknown API state (no response and no error message)."}) + result_text = json.dumps( + {"result": "Unknown API state (no response and no error message)."} + ) logger.error("Batch search: Unknown API state.") return result_text, metadata diff --git a/Agent0/executor_train/verl/verl/tools/utils/tool_registry.py b/Agent0/executor_train/verl/verl/tools/utils/tool_registry.py index 
5c14d10..d7b821b 100644 --- a/Agent0/executor_train/verl/verl/tools/utils/tool_registry.py +++ b/Agent0/executor_train/verl/verl/tools/utils/tool_registry.py @@ -37,8 +37,14 @@ async def initialize_mcp_tool(tool_cls, tool_config) -> list: tool_list = [] mcp_servers_config_path = tool_config.mcp.mcp_servers_config_path - tool_selected_list = tool_config.mcp.tool_selected_list if "tool_selected_list" in tool_config.mcp else None - await ClientManager.initialize(mcp_servers_config_path, tool_config.config.rate_limit) + tool_selected_list = ( + tool_config.mcp.tool_selected_list + if "tool_selected_list" in tool_config.mcp + else None + ) + await ClientManager.initialize( + mcp_servers_config_path, tool_config.config.rate_limit + ) # Wait for MCP client to be ready max_retries = 10 retry_interval = 2 # seconds @@ -47,7 +53,9 @@ async def initialize_mcp_tool(tool_cls, tool_config) -> list: if tool_schemas: break if i < max_retries - 1: - logger.debug(f"Waiting for MCP client to be ready, attempt {i + 1}/{max_retries}") + logger.debug( + f"Waiting for MCP client to be ready, attempt {i + 1}/{max_retries}" + ) await asyncio.sleep(retry_interval) else: raise RuntimeError("Failed to initialize MCP tools after maximum retries") @@ -91,8 +99,12 @@ def initialize_tools_from_config(tools_config_file): if tool_config.get("tool_schema", None) is None: tool_schema = None else: - tool_schema_dict = OmegaConf.to_container(tool_config.tool_schema, resolve=True) - tool_schema = OpenAIFunctionToolSchema.model_validate(tool_schema_dict) + tool_schema_dict = OmegaConf.to_container( + tool_config.tool_schema, resolve=True + ) + tool_schema = OpenAIFunctionToolSchema.model_validate( + tool_schema_dict + ) tool = tool_cls( config=OmegaConf.to_container(tool_config.config, resolve=True), tool_schema=tool_schema, @@ -100,7 +112,9 @@ def initialize_tools_from_config(tools_config_file): tool_list.append(tool) case ToolType.MCP: loop = asyncio.get_event_loop() - mcp_tools = 
loop.run_until_complete(initialize_mcp_tool(tool_cls, tool_config)) + mcp_tools = loop.run_until_complete( + initialize_mcp_tool(tool_cls, tool_config) + ) tool_list.extend(mcp_tools) case _: raise NotImplementedError diff --git a/Agent0/executor_train/verl/verl/trainer/fsdp_sft_trainer.py b/Agent0/executor_train/verl/verl/trainer/fsdp_sft_trainer.py index 531ebab..02d8b62 100644 --- a/Agent0/executor_train/verl/verl/trainer/fsdp_sft_trainer.py +++ b/Agent0/executor_train/verl/verl/trainer/fsdp_sft_trainer.py @@ -43,8 +43,16 @@ import verl.utils.hdfs_io as hdfs_io from verl.utils.dataset import SFTDataset from verl.utils.dataset.multiturn_sft_dataset import MultiTurnSFTDataset -from verl.utils.device import get_device_id, get_device_name, is_cuda_available, is_npu_available -from verl.utils.distributed import destroy_global_process_group, initialize_global_process_group +from verl.utils.device import ( + get_device_id, + get_device_name, + is_cuda_available, + is_npu_available, +) +from verl.utils.distributed import ( + destroy_global_process_group, + initialize_global_process_group, +) from verl.utils.fs import copy_to_local from verl.utils.fsdp_utils import ( CPUOffloadPolicy, @@ -59,7 +67,10 @@ from verl.utils.profiler import log_gpu_memory_usage from verl.utils.py_functional import convert_to_regular_types from verl.utils.torch_dtypes import PrecisionType -from verl.utils.torch_functional import get_cosine_schedule_with_warmup, get_wsd_schedule_with_warmup +from verl.utils.torch_functional import ( + get_cosine_schedule_with_warmup, + get_wsd_schedule_with_warmup, +) from verl.utils.tracking import Tracking from verl.utils.ulysses import ( gather_outpus_and_unpad, @@ -69,9 +80,19 @@ from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager if is_cuda_available: - from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input + from flash_attn.bert_padding import ( + index_first_axis, + pad_input, + rearrange, + 
unpad_input, + ) elif is_npu_available: - from transformers.integrations.npu_flash_attention import index_first_axis, pad_input, rearrange, unpad_input + from transformers.integrations.npu_flash_attention import ( + index_first_axis, + pad_input, + rearrange, + unpad_input, + ) logger = logging.getLogger(__file__) logger.setLevel(os.getenv("VERL_SFT_LOGGING_LEVEL", "WARN")) @@ -106,10 +127,14 @@ def __init__( self._normalize_config_bsz() # Set sequence parallel size - self.config.ulysses_sequence_parallel_size = getattr(self.config, "ulysses_sequence_parallel_size", 1) + self.config.ulysses_sequence_parallel_size = getattr( + self.config, "ulysses_sequence_parallel_size", 1 + ) self.use_remove_padding = getattr(self.config, "use_remove_padding", False) if self.device_mesh.get_rank() == 0: - print(f"Using sequence parallel size: {self.config.ulysses_sequence_parallel_size}") + print( + f"Using sequence parallel size: {self.config.ulysses_sequence_parallel_size}" + ) print(f"Using remove padding: {self.use_remove_padding}") self._build_dataloader(train_dataset, val_dataset) @@ -122,17 +147,25 @@ def __init__( self.device_name = get_device_name() def _normalize_config_bsz(self): - dp_size = self.device_mesh.size(0) if not self.ulysses_device_mesh else self.ulysses_device_mesh.size(0) + dp_size = ( + self.device_mesh.size(0) + if not self.ulysses_device_mesh + else self.ulysses_device_mesh.size(0) + ) if self.device_mesh.get_rank() == 0: print(f"Normalize batch size by dp {dp_size}") - assert self.config.data.train_batch_size % dp_size == 0, ( - f"Global batch size {self.config.data.train_batch_size} is not divisible by dp size {dp_size}" - ) + assert ( + self.config.data.train_batch_size % dp_size == 0 + ), f"Global batch size {self.config.data.train_batch_size} is not divisible by dp size {dp_size}" self.config.data.train_batch_size //= dp_size - assert self.config.data.train_batch_size % self.config.data.micro_batch_size_per_gpu == 0 + assert ( + 
self.config.data.train_batch_size + % self.config.data.micro_batch_size_per_gpu + == 0 + ) def _build_dataloader(self, train_dataset, val_dataset): # build dataset @@ -147,8 +180,12 @@ def _build_dataloader(self, train_dataset, val_dataset): rank = self.ulysses_device_mesh.get_local_rank("dp") world_size = self.ulysses_device_mesh.size(0) if self.ulysses_device_mesh.get_rank() == 0: - print(f"Using SP rank {rank} and size {world_size} for data distribution") - print("Each SP rank gets different data, but the same data WITHIN the same rank") + print( + f"Using SP rank {rank} and size {world_size} for data distribution" + ) + print( + "Each SP rank gets different data, but the same data WITHIN the same rank" + ) else: rank = self.device_mesh.get_rank() world_size = self.device_mesh.size() @@ -156,7 +193,11 @@ def _build_dataloader(self, train_dataset, val_dataset): print(f"Using FSDP rank {rank} and size {world_size} for data distribution") self.train_sampler = DistributedSampler( - self.train_dataset, shuffle=True, num_replicas=world_size, rank=rank, drop_last=True + self.train_dataset, + shuffle=True, + num_replicas=world_size, + rank=rank, + drop_last=True, ) self.train_dataloader = DataLoader( dataset=self.train_dataset, @@ -168,7 +209,11 @@ def _build_dataloader(self, train_dataset, val_dataset): ) self.val_sampler = DistributedSampler( - self.val_dataset, shuffle=False, num_replicas=world_size, rank=rank, drop_last=True + self.val_dataset, + shuffle=False, + num_replicas=world_size, + rank=rank, + drop_last=True, ) self.val_dataloader = DataLoader( dataset=self.val_dataset, @@ -183,7 +228,9 @@ def _build_model_optimizer(self): # TODO (zhangchi.usc1992): # 1. support pretrain from random weights # 2. 
support init directly from sharded weights - local_model_path = copy_to_local(src=self.config.model.partial_pretrain, verbose=True) + local_model_path = copy_to_local( + src=self.config.model.partial_pretrain, verbose=True + ) if self.config.model.get("external_lib", None) is not None: # This is used to import external_lib into the huggingface systems @@ -197,14 +244,18 @@ def _build_model_optimizer(self): torch_dtype = self.config.model.fsdp_config.get("model_dtype", "fp32") torch_dtype = PrecisionType.to_dtype(torch_dtype) # load config first - config = AutoConfig.from_pretrained(local_model_path, trust_remote_code=trust_remote_code) + config = AutoConfig.from_pretrained( + local_model_path, trust_remote_code=trust_remote_code + ) self.model_config = config if hasattr(self.model_config, "max_position_embeddings"): self.model_config.max_position_embeddings = max( self.model_config.max_position_embeddings, self.config.data.max_length ) if self.config.ulysses_sequence_parallel_size > 1: - assert self.use_remove_padding, "Sequence parallel is only supported when remove_padding is enabled" + assert ( + self.use_remove_padding + ), "Sequence parallel is only supported when remove_padding is enabled" # This may be very large init_context = get_init_weight_context_manager( @@ -220,14 +271,22 @@ def _build_model_optimizer(self): trust_remote_code=trust_remote_code, ) - if self.use_remove_padding or self.config.ulysses_sequence_parallel_size > 1: + if ( + self.use_remove_padding + or self.config.ulysses_sequence_parallel_size > 1 + ): from verl.models.transformers.monkey_patch import apply_monkey_patch - apply_monkey_patch(model=self.model, ulysses_sp_size=self.config.ulysses_sequence_parallel_size) + apply_monkey_patch( + model=self.model, + ulysses_sp_size=self.config.ulysses_sequence_parallel_size, + ) # Apply Liger kernel if use_liger is enabled if self.config.model.get("use_liger", False): - from liger_kernel.transformers.monkey_patch import 
_apply_liger_kernel_to_instance + from liger_kernel.transformers.monkey_patch import ( + _apply_liger_kernel_to_instance, + ) _apply_liger_kernel_to_instance(model=self.model) @@ -238,18 +297,24 @@ def _build_model_optimizer(self): "task_type": TaskType.CAUSAL_LM, "r": self.config.model.lora_rank, "lora_alpha": self.config.model.lora_alpha, - "target_modules": convert_to_regular_types(self.config.model.target_modules), + "target_modules": convert_to_regular_types( + self.config.model.target_modules + ), "bias": "none", } self.model = get_peft_model(self.model, LoraConfig(**lora_config)) if self.config.model.enable_gradient_checkpointing: - self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) + self.model.gradient_checkpointing_enable( + gradient_checkpointing_kwargs={"use_reentrant": False} + ) log_gpu_memory_usage("After model allocation", logger=logger) mixed_precision = MixedPrecision( - param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.float32 + param_dtype=torch.bfloat16, + reduce_dtype=torch.float32, + buffer_dtype=torch.float32, ) auto_wrap_policy = get_fsdp_wrap_policy( @@ -263,7 +328,9 @@ def _build_model_optimizer(self): if not self.config.model.fsdp_config.cpu_offload: cpu_offload = None else: - cpu_offload = CPUOffload(offload_params=self.config.model.fsdp_config.offload_params) + cpu_offload = CPUOffload( + offload_params=self.config.model.fsdp_config.offload_params + ) fsdp_strategy = self.config.model.strategy if fsdp_strategy == "fsdp": @@ -281,9 +348,13 @@ def _build_model_optimizer(self): forward_prefetch=False, ) elif fsdp_strategy == "fsdp2": - assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)" + assert ( + CPUOffloadPolicy is not None + ), "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)" mp_policy = MixedPrecisionPolicy( - param_dtype=torch.bfloat16, reduce_dtype=torch.float32, 
cast_forward_inputs=True + param_dtype=torch.bfloat16, + reduce_dtype=torch.float32, + cast_forward_inputs=True, ) fsdp_kwargs = { @@ -294,7 +365,9 @@ def _build_model_optimizer(self): } full_state = self.model.state_dict() apply_fsdp2(self.model, fsdp_kwargs, self.config.model.fsdp_config) - fsdp2_load_full_state_dict(self.model, full_state, self.device_mesh, cpu_offload) + fsdp2_load_full_state_dict( + self.model, full_state, self.device_mesh, cpu_offload + ) self.fsdp_model = self.model else: raise NotImplementedError(f"not implement {fsdp_strategy}") @@ -321,20 +394,29 @@ def _build_model_optimizer(self): num_warmup_steps = int(self.total_steps * self.config.optim.warmup_steps_ratio) - if not hasattr(self.config.optim, "lr_scheduler") or self.config.optim.lr_scheduler == "cosine": + if ( + not hasattr(self.config.optim, "lr_scheduler") + or self.config.optim.lr_scheduler == "cosine" + ): self.lr_scheduler = get_cosine_schedule_with_warmup( - optimizer=self.optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=self.total_steps + optimizer=self.optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=self.total_steps, ) elif self.config.optim.lr_scheduler == "wsd": self.lr_scheduler = get_wsd_schedule_with_warmup( - optimizer=self.optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=self.total_steps + optimizer=self.optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=self.total_steps, ) else: raise ValueError(f"Unknown lr scheduler: {self.config.optim.lr_scheduler}") def _compute_loss_and_backward(self, batch, do_backward=True): """Compute loss with optional sequence parallelism and remove padding features""" - use_sp = self.use_remove_padding and self.config.ulysses_sequence_parallel_size > 1 + use_sp = ( + self.use_remove_padding and self.config.ulysses_sequence_parallel_size > 1 + ) # Move inputs to GPU and prepare loss mask input_ids = batch["input_ids"].to(self.device_name) @@ -345,12 +427,17 @@ def 
_compute_loss_and_backward(self, batch, do_backward=True): # Context manager for sequence parallel if needed context = self.sharding_manager if use_sp else nullcontext() - with context, torch.autocast(device_type=self.device_name, dtype=torch.bfloat16): + with context, torch.autocast( + device_type=self.device_name, dtype=torch.bfloat16 + ): if not use_sp: # Standard forward pass without sequence parallel labels = input_ids[:, 1:].contiguous() output = self.fsdp_model( - input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + use_cache=False, ) logits = output.logits @@ -379,19 +466,30 @@ def _compute_loss_and_backward(self, batch, do_backward=True): # Unpad position_ids to align rotary position_ids_rmpad = index_first_axis( - rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices + rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), + indices, ).transpose(0, 1) # Pad and slice inputs for sequence parallelism - input_ids_rmpad_sliced, position_ids_rmpad_padded, pad_size = ulysses_pad_and_slice_inputs( - input_ids_rmpad, position_ids_rmpad, sp_size=get_ulysses_sequence_parallel_world_size() + input_ids_rmpad_sliced, position_ids_rmpad_padded, pad_size = ( + ulysses_pad_and_slice_inputs( + input_ids_rmpad, + position_ids_rmpad, + sp_size=get_ulysses_sequence_parallel_world_size(), + ) ) # For computing loss - input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1) # (1, total_nnz) + input_ids_rmpad_rolled = torch.roll( + input_ids_rmpad, shifts=-1, dims=1 + ) # (1, total_nnz) input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs( - input_ids_rmpad_rolled, None, get_ulysses_sequence_parallel_world_size() + input_ids_rmpad_rolled, + None, + get_ulysses_sequence_parallel_world_size(), ) - input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0) # ((total_nnz / sp) + pad) + input_ids_rmpad_rolled 
= input_ids_rmpad_rolled.squeeze( + 0 + ) # ((total_nnz / sp) + pad) # Forward pass output = self.fsdp_model( @@ -406,11 +504,16 @@ def _compute_loss_and_backward(self, batch, do_backward=True): input_ids_rmpad_rolled = input_ids_rmpad_rolled.to(logits_rmpad.device) loss = loss_fct(logits_rmpad, input_ids_rmpad_rolled) # Gather and unpad for sequence parallelism - loss = gather_outpus_and_unpad(loss, gather_dim=0, unpad_dim=0, padding_size=pad_size) + loss = gather_outpus_and_unpad( + loss, gather_dim=0, unpad_dim=0, padding_size=pad_size + ) # This is the loss collected from all ulysses ranks full_loss = pad_input( - hidden_states=loss.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen + hidden_states=loss.unsqueeze(-1), + indices=indices, + batch=batch_size, + seqlen=seqlen, ) full_loss = full_loss.squeeze(-1)[:, :-1] # Remove last token's loss full_loss = full_loss.reshape(-1) @@ -421,7 +524,11 @@ def _compute_loss_and_backward(self, batch, do_backward=True): if self.config.data.balance_dp_token: torch.distributed.all_reduce(valid_token_this_rank) - dp_size = self.ulysses_device_mesh.size("dp") if use_sp else torch.distributed.get_world_size() + dp_size = ( + self.ulysses_device_mesh.size("dp") + if use_sp + else torch.distributed.get_world_size() + ) else: dp_size = 1 @@ -448,9 +555,13 @@ def training_step(self, batch: TensorDict): step_loss += loss.item() if self.config.model.strategy == "fsdp": - grad_norm = self.fsdp_model.clip_grad_norm_(max_norm=self.config.optim.clip_grad) + grad_norm = self.fsdp_model.clip_grad_norm_( + max_norm=self.config.optim.clip_grad + ) elif self.config.model.strategy == "fsdp2": - grad_norm = fsdp2_clip_grad_norm_(self.fsdp_model.parameters(), max_norm=self.config.optim.clip_grad) + grad_norm = fsdp2_clip_grad_norm_( + self.fsdp_model.parameters(), max_norm=self.config.optim.clip_grad + ) else: raise NotImplementedError(f"not implement {self.config.model.strategy}") @@ -493,7 +604,9 @@ def validation_step(self, 
batch: TensorDict): def save_checkpoint(self, step): # save checkpoint - path = os.path.join(self.config.trainer.default_local_dir, f"global_step_{step}") + path = os.path.join( + self.config.trainer.default_local_dir, f"global_step_{step}" + ) fsdp_strategy = self.config.model.strategy if fsdp_strategy == "fsdp": @@ -501,7 +614,9 @@ def save_checkpoint(self, step): from torch.distributed.fsdp import FullStateDictConfig, StateDictType cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) - with FSDP.state_dict_type(self.fsdp_model, StateDictType.FULL_STATE_DICT, cfg): + with FSDP.state_dict_type( + self.fsdp_model, StateDictType.FULL_STATE_DICT, cfg + ): state_dict = self.fsdp_model.state_dict() # save huggingface model @@ -511,7 +626,10 @@ def save_checkpoint(self, step): self.tokenizer.save_pretrained(path) elif fsdp_strategy == "fsdp2": # FSDP2 checkpoint saving - from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict + from torch.distributed.checkpoint.state_dict import ( + StateDictOptions, + get_model_state_dict, + ) # Get full state dict with FSDP2 options = StateDictOptions(full_state_dict=True, cpu_offload=True) @@ -529,7 +647,9 @@ def save_checkpoint(self, step): # Copy to HDFS if configured if self.device_mesh.get_rank() == 0 and self.config.trainer.default_hdfs_dir: hdfs_io.makedirs(self.config.trainer.default_hdfs_dir, exist_ok=True) - hdfs_io.copy(src=path, dst=self.config.trainer.default_hdfs_dir, dirs_exist_ok=True) + hdfs_io.copy( + src=path, dst=self.config.trainer.default_hdfs_dir, dirs_exist_ok=True + ) torch.distributed.barrier() @@ -548,7 +668,9 @@ def fit(self): last_valid_metric = None # compute the total training steps. 
# the total training steps in SFT is mainly for early exit - total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs + total_training_steps = ( + len(self.train_dataloader) * self.config.trainer.total_epochs + ) if self.config.trainer.total_training_steps is not None: total_training_steps = self.config.trainer.total_training_steps @@ -568,7 +690,9 @@ def fit(self): disable=rank != 0, ): global_step += 1 - data = TensorDict(data, batch_size=self.config.data.train_batch_size).to(self.device_name) + data = TensorDict( + data, batch_size=self.config.data.train_batch_size + ).to(self.device_name) metric = self.training_step(data) if rank == 0: tracking.log(data=metric, step=global_step) @@ -578,13 +702,16 @@ def fit(self): is_save_step = global_step % self.config.trainer.save_freq == 0 # early exit or validation step - if is_last_step or (self.config.trainer.test_freq > 0 and is_valid_step): + if is_last_step or ( + self.config.trainer.test_freq > 0 and is_valid_step + ): # Perform validation val_losses = [] for val_data in self.val_dataloader: - val_data = TensorDict(val_data, batch_size=self.config.data.micro_batch_size_per_gpu).to( - self.device_name - ) + val_data = TensorDict( + val_data, + batch_size=self.config.data.micro_batch_size_per_gpu, + ).to(self.device_name) val_loss = self.validation_step(val_data) val_losses.append(val_loss) if rank == 0: @@ -607,7 +734,9 @@ def run_sft(config): device_name = get_device_name() local_rank, rank, world_size = initialize_global_process_group() - device_mesh = init_device_mesh(device_type=device_name, mesh_shape=(world_size,), mesh_dim_names=("fsdp",)) + device_mesh = init_device_mesh( + device_type=device_name, mesh_shape=(world_size,), mesh_dim_names=("fsdp",) + ) dp_size = world_size // config.ulysses_sequence_parallel_size ulysses_device_mesh = init_device_mesh( device_type=device_name, @@ -618,7 +747,9 @@ def run_sft(config): from verl.utils import hf_tokenizer local_model_path = 
copy_to_local(src=config.model.partial_pretrain, verbose=True) - tokenizer = hf_tokenizer(local_model_path, trust_remote_code=config.model.trust_remote_code) + tokenizer = hf_tokenizer( + local_model_path, trust_remote_code=config.model.trust_remote_code + ) train_dataset = create_sft_dataset(config.data.train_files, config.data, tokenizer) val_dataset = create_sft_dataset(config.data.val_files, config.data, tokenizer) @@ -648,7 +779,9 @@ def create_sft_dataset(data_paths, data_config, tokenizer): if data_config.custom_cls.get("path", None): from verl.utils.import_utils import load_extern_type - dataset_cls = load_extern_type(data_config.custom_cls.path, data_config.custom_cls.name) + dataset_cls = load_extern_type( + data_config.custom_cls.path, data_config.custom_cls.name + ) # Then check if multi-turn dataset should be used elif data_config.get("multiturn", {}).get("enable", False): dataset_cls = MultiTurnSFTDataset @@ -657,7 +790,9 @@ def create_sft_dataset(data_paths, data_config, tokenizer): dataset_cls = SFTDataset # Create datasets based on the selected class - dataset = dataset_cls(parquet_files=data_paths, tokenizer=tokenizer, config=data_config) + dataset = dataset_cls( + parquet_files=data_paths, tokenizer=tokenizer, config=data_config + ) return dataset diff --git a/Agent0/executor_train/verl/verl/trainer/main_eval.py b/Agent0/executor_train/verl/verl/trainer/main_eval.py index 0a5c581..1eefa9a 100644 --- a/Agent0/executor_train/verl/verl/trainer/main_eval.py +++ b/Agent0/executor_train/verl/verl/trainer/main_eval.py @@ -38,7 +38,9 @@ def process_item(reward_fn, data_source, response_lst, reward_data): @hydra.main(config_path="config", config_name="evaluation", version_base=None) def main(config): - local_path = copy_to_local(config.data.path, use_shm=config.data.get("use_shm", False)) + local_path = copy_to_local( + config.data.path, use_shm=config.data.get("use_shm", False) + ) dataset = pd.read_parquet(local_path) responses = 
dataset[config.data.response_key] data_sources = dataset[config.data.data_source_key] @@ -56,7 +58,10 @@ def main(config): # Create remote tasks remote_tasks = [ - process_item.remote(compute_score, data_sources[i], responses[i], reward_model_data[i]) for i in range(total) + process_item.remote( + compute_score, data_sources[i], responses[i], reward_model_data[i] + ) + for i in range(total) ] # Process results as they come in diff --git a/Agent0/executor_train/verl/verl/trainer/main_generation.py b/Agent0/executor_train/verl/verl/trainer/main_generation.py index b8174ad..a021f1a 100644 --- a/Agent0/executor_train/verl/verl/trainer/main_generation.py +++ b/Agent0/executor_train/verl/verl/trainer/main_generation.py @@ -32,7 +32,11 @@ from verl import DataProto from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto -from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup +from verl.single_controller.ray import ( + RayClassWithInitArgs, + RayResourcePool, + RayWorkerGroup, +) from verl.utils import hf_tokenizer from verl.utils.fs import copy_to_local from verl.utils.hdfs_io import makedirs @@ -49,7 +53,9 @@ def run_generation(config) -> None: if not ray.is_initialized(): # this is for local ray cluster ray.init( - runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}}, + runtime_env={ + "env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"} + }, num_cpus=config.ray_init.num_cpus, ) @@ -58,7 +64,9 @@ def run_generation(config) -> None: @ray.remote(num_cpus=1) def main_task(config): - pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values + pprint( + OmegaConf.to_container(config, resolve=True) + ) # resolve=True will eval symbol values OmegaConf.resolve(config) local_path = copy_to_local(config.model.path) @@ -79,8 +87,12 @@ def main_task(config): if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token - ray_cls_with_init = 
RayClassWithInitArgs(cls=ray.remote(ActorRolloutRefWorker), config=config, role="rollout") - resource_pool = RayResourcePool(process_on_nodes=[config.trainer.n_gpus_per_node] * config.trainer.nnodes) + ray_cls_with_init = RayClassWithInitArgs( + cls=ray.remote(ActorRolloutRefWorker), config=config, role="rollout" + ) + resource_pool = RayResourcePool( + process_on_nodes=[config.trainer.n_gpus_per_node] * config.trainer.nnodes + ) wg = RayWorkerGroup( resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init, @@ -95,7 +107,9 @@ def main_task(config): for batch_idx in range(num_batch): print(f"[{batch_idx + 1}/{num_batch}] Start to process.") - batch_chat_lst = chat_lst[batch_idx * config_batch_size : (batch_idx + 1) * config_batch_size] + batch_chat_lst = chat_lst[ + batch_idx * config_batch_size : (batch_idx + 1) * config_batch_size + ] inputs = tokenizer.apply_chat_template( batch_chat_lst, add_generation_prompt=True, @@ -109,7 +123,11 @@ def main_task(config): input_ids = inputs["input_ids"] attention_mask = inputs["attention_mask"] position_ids = compute_position_id_with_mask(attention_mask) - batch_dict = {"input_ids": input_ids, "attention_mask": attention_mask, "position_ids": position_ids} + batch_dict = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "position_ids": position_ids, + } data = DataProto.from_dict(batch_dict) data_padded, pad_size = pad_dataproto_to_divisor(data, wg.world_size) @@ -124,9 +142,15 @@ def main_task(config): for i in range(len(output)): data_item = output[i] prompt_length = data_item.batch["prompts"].shape[-1] - valid_response_length = data_item.batch["attention_mask"][prompt_length:].sum() - valid_response_ids = data_item.batch["responses"][:valid_response_length] - response_str = tokenizer.decode(valid_response_ids, skip_special_tokens=True) + valid_response_length = data_item.batch["attention_mask"][ + prompt_length: + ].sum() + valid_response_ids = data_item.batch["responses"][ + :valid_response_length 
+ ] + response_str = tokenizer.decode( + valid_response_ids, skip_special_tokens=True + ) output_texts.append(response_str) output_lst[n_sample].extend(output_texts) diff --git a/Agent0/executor_train/verl/verl/trainer/main_ppo.py b/Agent0/executor_train/verl/verl/trainer/main_ppo.py index 2a0b21d..b64449f 100644 --- a/Agent0/executor_train/verl/verl/trainer/main_ppo.py +++ b/Agent0/executor_train/verl/verl/trainer/main_ppo.py @@ -67,7 +67,9 @@ def run_ppo(config) -> None: and OmegaConf.select(config.trainer, "profile_steps") is not None and len(OmegaConf.select(config.trainer, "profile_steps")) > 0 ): - nsight_options = OmegaConf.to_container(config.trainer.controller_nsight_options) + nsight_options = OmegaConf.to_container( + config.trainer.controller_nsight_options + ) runner = TaskRunner.options(runtime_env={"nsight": nsight_options}).remote() else: runner = TaskRunner.remote() @@ -114,7 +116,8 @@ def run(self, config): # Download the checkpoint from HDFS to the local machine. # `use_shm` determines whether to use shared memory, which could lead to faster model loading if turned on local_path = copy_to_local( - config.actor_rollout_ref.model.path, use_shm=config.actor_rollout_ref.model.get("use_shm", False) + config.actor_rollout_ref.model.path, + use_shm=config.actor_rollout_ref.model.get("use_shm", False), ) # Instantiate the tokenizer and processor. @@ -123,7 +126,9 @@ def run(self, config): trust_remote_code = config.data.get("trust_remote_code", False) tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code) # Used for multimodal LLM, could be None - processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True) + processor = hf_processor( + local_path, trust_remote_code=trust_remote_code, use_fast=True + ) # Version validation for vllm. 
if config.actor_rollout_ref.rollout.name in ["vllm"]: @@ -131,13 +136,19 @@ def run(self, config): if config.actor_rollout_ref.model.get("lora_rank", 0) > 0: if not is_version_ge(pkg="vllm", minver="0.7.3"): - raise NotImplementedError("PPO LoRA is not supported before vllm 0.7.3") + raise NotImplementedError( + "PPO LoRA is not supported before vllm 0.7.3" + ) # Define worker classes based on the actor strategy. if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}: assert config.critic.strategy in {"fsdp", "fsdp2"} from verl.single_controller.ray import RayWorkerGroup - from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker, CriticWorker + from verl.workers.fsdp_workers import ( + ActorRolloutRefWorker, + AsyncActorRolloutRefWorker, + CriticWorker, + ) actor_rollout_cls = ( AsyncActorRolloutRefWorker @@ -149,7 +160,11 @@ def run(self, config): elif config.actor_rollout_ref.actor.strategy == "megatron": assert config.actor_rollout_ref.actor.strategy == config.critic.strategy from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup - from verl.workers.megatron_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker, CriticWorker + from verl.workers.megatron_workers import ( + ActorRolloutRefWorker, + AsyncActorRolloutRefWorker, + CriticWorker, + ) actor_rollout_cls = ( AsyncActorRolloutRefWorker @@ -197,24 +212,39 @@ def run(self, config): mapping[Role.RewardModel] = global_pool_id # Add a reference policy worker if KL loss or KL reward is used. - if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss: + if ( + config.algorithm.use_kl_in_reward + or config.actor_rollout_ref.actor.use_kl_loss + ): role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker) mapping[Role.RefPolicy] = global_pool_id # Load the reward manager for training and validation. 
reward_fn = load_reward_manager( - config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {}) + config, + tokenizer, + num_examine=0, + **config.reward_model.get("reward_kwargs", {}), ) val_reward_fn = load_reward_manager( - config, tokenizer, num_examine=1, **config.reward_model.get("reward_kwargs", {}) + config, + tokenizer, + num_examine=1, + **config.reward_model.get("reward_kwargs", {}), + ) + resource_pool_manager = ResourcePoolManager( + resource_pool_spec=resource_pool_spec, mapping=mapping ) - resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) from verl.utils.dataset.rl_dataset import collate_fn # Create training and validation datasets. - train_dataset = create_rl_dataset(config.data.train_files, config.data, tokenizer, processor, is_train=True) - val_dataset = create_rl_dataset(config.data.val_files, config.data, tokenizer, processor, is_train=False) + train_dataset = create_rl_dataset( + config.data.train_files, config.data, tokenizer, processor, is_train=True + ) + val_dataset = create_rl_dataset( + config.data.val_files, config.data, tokenizer, processor, is_train=False + ) train_sampler = create_rl_sampler(config.data, train_dataset) # Initialize the PPO trainer. 
@@ -257,16 +287,25 @@ def create_rl_dataset(data_paths, data_config, tokenizer, processor, is_train=Tr # Check if a custom dataset class is specified in the data configuration # and if the path to the custom class is provided - if "custom_cls" in data_config and data_config.custom_cls.get("path", None) is not None: + if ( + "custom_cls" in data_config + and data_config.custom_cls.get("path", None) is not None + ): # Dynamically load the custom dataset class - dataset_cls = load_extern_type(data_config.custom_cls.path, data_config.custom_cls.name) + dataset_cls = load_extern_type( + data_config.custom_cls.path, data_config.custom_cls.name + ) # Verify that the custom dataset class inherits from torch.utils.data.Dataset if not issubclass(dataset_cls, Dataset): raise TypeError( f"The custom dataset class '{data_config.custom_cls.name}' from " f"'{data_config.custom_cls.path}' must inherit from torch.utils.data.Dataset" ) - elif "datagen" in data_config and data_config.datagen.get("path", None) is not None and is_train: + elif ( + "datagen" in data_config + and data_config.datagen.get("path", None) is not None + and is_train + ): # If a data generation strategy is specified, use the DynamicGenDataset class from verl.utils.dataset.dynamicgen_dataset import DynamicGenDataset @@ -302,7 +341,10 @@ def create_rl_sampler(data_config, dataset): import torch from torch.utils.data import RandomSampler, SequentialSampler - if data_config.sampler is not None and data_config.sampler.get("class_path", None) is not None: + if ( + data_config.sampler is not None + and data_config.sampler.get("class_path", None) is not None + ): curriculum_class = load_extern_type( data_config.sampler.class_path, data_config.sampler.class_name, @@ -323,7 +365,9 @@ def create_rl_sampler(data_config, dataset): elif data_config.shuffle: train_dataloader_generator = torch.Generator() train_dataloader_generator.manual_seed(data_config.get("seed", 1)) - sampler = RandomSampler(data_source=dataset, 
generator=train_dataloader_generator) + sampler = RandomSampler( + data_source=dataset, generator=train_dataloader_generator + ) else: # If shuffling is disabled, use a sequential sampler to iterate through the dataset in order. sampler = SequentialSampler(data_source=dataset) diff --git a/Agent0/executor_train/verl/verl/trainer/ppo/core_algos.py b/Agent0/executor_train/verl/verl/trainer/ppo/core_algos.py index 5f02675..5e59129 100644 --- a/Agent0/executor_train/verl/verl/trainer/ppo/core_algos.py +++ b/Agent0/executor_train/verl/verl/trainer/ppo/core_algos.py @@ -184,8 +184,14 @@ def get_kl_controller(kl_ctrl): if kl_ctrl.type == "fixed": return FixedKLController(kl_coef=kl_ctrl.kl_coef) elif kl_ctrl.type == "adaptive": - assert kl_ctrl.horizon > 0, f"horizon must be larger than 0. Got {kl_ctrl.horizon}" - return AdaptiveKLController(init_kl_coef=kl_ctrl.kl_coef, target_kl=kl_ctrl.target_kl, horizon=kl_ctrl.horizon) + assert ( + kl_ctrl.horizon > 0 + ), f"horizon must be larger than 0. 
Got {kl_ctrl.horizon}" + return AdaptiveKLController( + init_kl_coef=kl_ctrl.kl_coef, + target_kl=kl_ctrl.target_kl, + horizon=kl_ctrl.horizon, + ) else: raise NotImplementedError @@ -230,8 +236,14 @@ def compute_gae_advantage_return( lastgaelam_ = delta + gamma * lam * lastgaelam # skip values and TD-error on observation tokens - nextvalues = values[:, t] * response_mask[:, t] + (1 - response_mask[:, t]) * nextvalues - lastgaelam = lastgaelam_ * response_mask[:, t] + (1 - response_mask[:, t]) * lastgaelam + nextvalues = ( + values[:, t] * response_mask[:, t] + + (1 - response_mask[:, t]) * nextvalues + ) + lastgaelam = ( + lastgaelam_ * response_mask[:, t] + + (1 - response_mask[:, t]) * lastgaelam + ) advantages_reversed.append(lastgaelam) advantages = torch.stack(advantages_reversed[::-1], dim=1) @@ -300,7 +312,9 @@ def compute_grpo_outcome_advantage( raise ValueError(f"no score in prompt index: {idx}") for i in range(bsz): if norm_adv_by_std_in_grpo: - scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon) + scores[i] = (scores[i] - id2mean[index[i]]) / ( + id2std[index[i]] + epsilon + ) else: scores[i] = scores[i] - id2mean[index[i]] scores = scores.unsqueeze(-1) * response_mask @@ -308,7 +322,9 @@ def compute_grpo_outcome_advantage( return scores, scores -@register_adv_est(AdvantageEstimator.GRPO_PASSK) # or simply: @register_adv_est("grpo_passk") +@register_adv_est( + AdvantageEstimator.GRPO_PASSK +) # or simply: @register_adv_est("grpo_passk") def compute_grpo_passk_outcome_advantage( token_level_rewards: torch.Tensor, response_mask: torch.Tensor, @@ -468,9 +484,9 @@ def compute_rloo_outcome_advantage( for i in range(bsz): response_num = len(id2score[index[i]]) if response_num > 1: - scores[i] = scores[i] * response_num / (response_num - 1) - id2mean[index[i]] * response_num / ( - response_num - 1 - ) + scores[i] = scores[i] * response_num / (response_num - 1) - id2mean[ + index[i] + ] * response_num / (response_num - 1) scores = 
scores.unsqueeze(-1) * response_mask return scores, scores @@ -530,9 +546,14 @@ def compute_opo_outcome_advantage( return scores, scores -@register_adv_est(AdvantageEstimator.REINFORCE_PLUS_PLUS) # or simply: @register_adv_est("reinforce_plus_plus") +@register_adv_est( + AdvantageEstimator.REINFORCE_PLUS_PLUS +) # or simply: @register_adv_est("reinforce_plus_plus") def compute_reinforce_plus_plus_outcome_advantage( - token_level_rewards: torch.Tensor, response_mask: torch.Tensor, config: Optional[AlgoConfig] = None, **kwargs + token_level_rewards: torch.Tensor, + response_mask: torch.Tensor, + config: Optional[AlgoConfig] = None, + **kwargs, ) -> tuple[torch.Tensor, torch.Tensor]: """ Compute advantage for REINFORCE++. @@ -599,7 +620,12 @@ def compute_remax_outcome_advantage( """ with torch.no_grad(): - returns = (token_level_rewards * response_mask).flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1]) + returns = ( + (token_level_rewards * response_mask) + .flip(dims=[-1]) + .cumsum(dim=-1) + .flip(dims=[-1]) + ) advantages = returns - reward_baselines.unsqueeze(-1) * response_mask return advantages, returns @@ -704,7 +730,9 @@ def agg_loss(loss_mat: torch.Tensor, loss_mask: torch.Tensor, loss_agg_mode: str seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) # token-sum loss = torch.mean(seq_losses) # seq-mean elif loss_agg_mode == "seq-mean-token-mean": - seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) / torch.sum(loss_mask, dim=-1) # token-mean + seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) / torch.sum( + loss_mask, dim=-1 + ) # token-mean loss = torch.mean(seq_losses) # seq-mean elif loss_agg_mode == "seq-mean-token-sum-norm": seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) @@ -780,7 +808,9 @@ def compute_policy_loss( clip_pg_losses1 = torch.maximum( pg_losses1, pg_losses2 ) # max(-ratio * A, -clip(ratio, 1-cliprange, 1+cliprange) * A) - pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses1).float(), response_mask) + pg_clipfrac = 
verl_F.masked_mean( + torch.gt(pg_losses2, pg_losses1).float(), response_mask + ) pg_losses3 = -advantages * clip_ratio_c clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1) @@ -789,13 +819,22 @@ def compute_policy_loss( ) pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1) - pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) + pg_loss = agg_loss( + loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode + ) return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower @register_policy_loss("gpg") -def compute_policy_loss_gpg(old_log_prob, log_prob, advantages, response_mask, loss_agg_mode="token-mean", config=None): +def compute_policy_loss_gpg( + old_log_prob, + log_prob, + advantages, + response_mask, + loss_agg_mode="token-mean", + config=None, +): """Adapted from https://github.com/AMAP-ML/GPG/blob/main/VisualThinker-R1-Zero/src/open-r1-multimodal/src/open_r1/trainer/grpo_trainer.py#L495 Args: @@ -811,7 +850,9 @@ def compute_policy_loss_gpg(old_log_prob, log_prob, advantages, response_mask, l """ pg_losses = -log_prob * advantages - pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) + pg_loss = agg_loss( + loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode + ) return pg_loss, torch.tensor(0.0), torch.tensor(0.0), torch.tensor(0.0) @@ -855,12 +896,28 @@ def compute_policy_loss_clip_cov( clip_cov_ub (float, optional): Upper bound for clipping covariance. Defaults to 5.0. 
""" - clip_cov_ratio = config.policy_loss.clip_cov_ratio if config.policy_loss.clip_cov_ratio is not None else 0.0002 + clip_cov_ratio = ( + config.policy_loss.clip_cov_ratio + if config.policy_loss.clip_cov_ratio is not None + else 0.0002 + ) cliprange = config.clip_ratio - cliprange_low = config.clip_ratio_low if config.clip_ratio_low is not None else cliprange - cliprange_high = config.clip_ratio_high if config.clip_ratio_high is not None else cliprange - clip_cov_ub = config.policy_loss.clip_cov_ub if config.policy_loss.clip_cov_ub is not None else 5.0 - clip_cov_lb = config.policy_loss.clip_cov_lb if config.policy_loss.clip_cov_lb is not None else 1.0 + cliprange_low = ( + config.clip_ratio_low if config.clip_ratio_low is not None else cliprange + ) + cliprange_high = ( + config.clip_ratio_high if config.clip_ratio_high is not None else cliprange + ) + clip_cov_ub = ( + config.policy_loss.clip_cov_ub + if config.policy_loss.clip_cov_ub is not None + else 5.0 + ) + clip_cov_lb = ( + config.policy_loss.clip_cov_lb + if config.policy_loss.clip_cov_lb is not None + else 1.0 + ) assert clip_cov_ratio > 0, "clip_ratio should be larger than 0." @@ -900,7 +957,9 @@ def compute_policy_loss_clip_cov( pg_clipfrac = verl_F.masked_mean((corr == 0).float(), response_mask) pg_losses = torch.maximum(pg_losses1, pg_losses2) * corr - pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) + pg_loss = agg_loss( + loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode + ) return pg_loss, pg_clipfrac, ppo_kl, torch.tensor(0.0) @@ -936,8 +995,16 @@ def compute_policy_loss_kl_cov( ppo_kl_coef (float, optional): Coefficient for the KL penalty term in the loss. Defaults to 1. 
""" - kl_cov_ratio = config.policy_loss.kl_cov_ratio if config.policy_loss.kl_cov_ratio is not None else 0.0002 - ppo_kl_coef = config.policy_loss.ppo_kl_coef if config.policy_loss.ppo_kl_coef is not None else 1.0 + kl_cov_ratio = ( + config.policy_loss.kl_cov_ratio + if config.policy_loss.kl_cov_ratio is not None + else 0.0002 + ) + ppo_kl_coef = ( + config.policy_loss.ppo_kl_coef + if config.policy_loss.ppo_kl_coef is not None + else 1.0 + ) assert kl_cov_ratio > 0, "kl_cov_ratio should be larger than 0." @@ -957,17 +1024,25 @@ def compute_policy_loss_kl_cov( k = min(kl_cov_ratio, len(all_valid_adv)) if k != 0: - cov_lst_all = (all_valid_adv - all_valid_adv.mean()) * (all_valid_logp - all_valid_logp.mean()) + cov_lst_all = (all_valid_adv - all_valid_adv.mean()) * ( + all_valid_logp - all_valid_logp.mean() + ) k_percent_nums = max(1, int(len(cov_lst_all) * kl_cov_ratio)) large_cov_idxs = torch.topk(cov_lst_all, k_percent_nums, largest=True).indices if len(large_cov_idxs) != 0: large_cov_idxs = all_valid_idx[large_cov_idxs] - pg_losses[large_cov_idxs // advantages.shape[1], large_cov_idxs % advantages.shape[1]] = pg_losses_kl[ - large_cov_idxs // advantages.shape[1], large_cov_idxs % advantages.shape[1] + pg_losses[ + large_cov_idxs // advantages.shape[1], + large_cov_idxs % advantages.shape[1], + ] = pg_losses_kl[ + large_cov_idxs // advantages.shape[1], + large_cov_idxs % advantages.shape[1], ] - pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) + pg_loss = agg_loss( + loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode + ) return pg_loss, torch.tensor(0.0), ppo_kl_abs, torch.tensor(0.0) @@ -985,7 +1060,9 @@ def compute_entropy_loss(logits, response_mask, loss_agg_mode: str = "token-mean """ # compute entropy token_entropy = verl_F.entropy_from_logits(logits) # (bs, response_len) - entropy_loss = agg_loss(loss_mat=token_entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) + entropy_loss = 
agg_loss( + loss_mat=token_entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode + ) return entropy_loss @@ -1022,16 +1099,24 @@ def compute_value_loss( vf_clipfrac (float): Fraction of elements where the clipped loss was used. """ - vpredclipped = verl_F.clip_by_value(vpreds, values - cliprange_value, values + cliprange_value) + vpredclipped = verl_F.clip_by_value( + vpreds, values - cliprange_value, values + cliprange_value + ) vf_losses1 = (vpreds - returns) ** 2 vf_losses2 = (vpredclipped - returns) ** 2 clipped_vf_losses = torch.max(vf_losses1, vf_losses2) - vf_loss = 0.5 * agg_loss(loss_mat=clipped_vf_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) - vf_clipfrac = verl_F.masked_mean(torch.gt(vf_losses2, vf_losses1).float(), response_mask) + vf_loss = 0.5 * agg_loss( + loss_mat=clipped_vf_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode + ) + vf_clipfrac = verl_F.masked_mean( + torch.gt(vf_losses2, vf_losses1).float(), response_mask + ) return vf_loss, vf_clipfrac -def kl_penalty(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_penalty) -> torch.FloatTensor: +def kl_penalty( + logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_penalty +) -> torch.FloatTensor: """Compute KL divergence given logprob and ref_logprob. Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1104 See more description in http://joschu.net/blog/kl-approx.html @@ -1086,7 +1171,9 @@ def compute_pf_ppo_reweight_data( """ @torch.no_grad() - def compute_weights(scores: torch.Tensor, reweight_method: str, weight_pow: float) -> torch.Tensor: + def compute_weights( + scores: torch.Tensor, reweight_method: str, weight_pow: float + ) -> torch.Tensor: """Compute importance weights for resampling based on scores. 
Args: @@ -1105,7 +1192,9 @@ def compute_weights(scores: torch.Tensor, reweight_method: str, weight_pow: floa elif reweight_method == "max_min": max_score = torch.max(scores) min_score = torch.min(scores) - weights = torch.where((scores == max_score) | (scores == min_score), 1.0, 0.0) + weights = torch.where( + (scores == max_score) | (scores == min_score), 1.0, 0.0 + ) elif reweight_method == "max_random": max_score = torch.max(scores) weights = torch.where(scores == max_score, 0.4, 0.1) @@ -1120,7 +1209,9 @@ def compute_weights(scores: torch.Tensor, reweight_method: str, weight_pow: floa batch_size = scores.shape[0] sample_indices = torch.multinomial(weights, batch_size, replacement=True) - resampled_batch = {key: tensor[sample_indices] for key, tensor in data.batch.items()} + resampled_batch = { + key: tensor[sample_indices] for key, tensor in data.batch.items() + } sample_indices_np = sample_indices.numpy() resampled_non_tensor_batch = {} diff --git a/Agent0/executor_train/verl/verl/trainer/ppo/metric_utils.py b/Agent0/executor_train/verl/verl/trainer/ppo/metric_utils.py index 3b6b47b..341e035 100644 --- a/Agent0/executor_train/verl/verl/trainer/ppo/metric_utils.py +++ b/Agent0/executor_train/verl/verl/trainer/ppo/metric_utils.py @@ -151,7 +151,9 @@ def compute_data_metrics(batch: DataProto, use_critic: bool = True) -> dict[str, "critic/values/max": torch.max(valid_values).detach().item(), "critic/values/min": torch.min(valid_values).detach().item(), # vf explained var - "critic/vf_explained_var": (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(), + "critic/vf_explained_var": (1.0 - return_diff_var / (return_var + 1e-5)) + .detach() + .item(), } if use_critic else {} @@ -160,14 +162,20 @@ def compute_data_metrics(batch: DataProto, use_critic: bool = True) -> dict[str, "response_length/mean": torch.mean(response_length).detach().item(), "response_length/max": torch.max(response_length).detach().item(), "response_length/min": 
torch.min(response_length).detach().item(), - "response_length/clip_ratio": torch.mean(torch.eq(response_length, max_response_length).float()) + "response_length/clip_ratio": torch.mean( + torch.eq(response_length, max_response_length).float() + ) .detach() .item(), # prompt length "prompt_length/mean": torch.mean(prompt_length).detach().item(), "prompt_length/max": torch.max(prompt_length).detach().item(), "prompt_length/min": torch.min(prompt_length).detach().item(), - "prompt_length/clip_ratio": torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(), + "prompt_length/clip_ratio": torch.mean( + torch.eq(prompt_length, max_prompt_length).float() + ) + .detach() + .item(), } # multi-turn conversation @@ -180,7 +188,9 @@ def compute_data_metrics(batch: DataProto, use_critic: bool = True) -> dict[str, return metrics -def compute_timing_metrics(batch: DataProto, timing_raw: dict[str, float]) -> dict[str, Any]: +def compute_timing_metrics( + batch: DataProto, timing_raw: dict[str, float] +) -> dict[str, Any]: """ Computes timing metrics for different processing stages in PPO training. 
@@ -210,19 +220,26 @@ def compute_timing_metrics(batch: DataProto, timing_raw: dict[str, float]) -> di num_tokens_of_section = { "gen": num_response_tokens, - **{name: num_overall_tokens for name in ["ref", "values", "adv", "update_critic", "update_actor"]}, + **{ + name: num_overall_tokens + for name in ["ref", "values", "adv", "update_critic", "update_actor"] + }, } return { **{f"timing_s/{name}": value for name, value in timing_raw.items()}, **{ - f"timing_per_token_ms/{name}": timing_raw[name] * 1000 / num_tokens_of_section[name] + f"timing_per_token_ms/{name}": timing_raw[name] + * 1000 + / num_tokens_of_section[name] for name in set(num_tokens_of_section.keys()) & set(timing_raw.keys()) }, } -def compute_throughout_metrics(batch: DataProto, timing_raw: dict[str, float], n_gpus: int) -> dict[str, Any]: +def compute_throughout_metrics( + batch: DataProto, timing_raw: dict[str, float], n_gpus: int +) -> dict[str, Any]: """ Computes throughput metrics for PPO training. @@ -336,7 +353,10 @@ def calc_maj_val(data: list[dict[str, Any]], vote_key: str, val_key: str) -> flo def process_validation_metrics( - data_sources: list[str], sample_inputs: list[str], infos_dict: dict[str, list[Any]], seed: int = 42 + data_sources: list[str], + sample_inputs: list[str], + infos_dict: dict[str, list[Any]], + seed: int = 42, ) -> dict[str, dict[str, dict[str, float]]]: """ Process validation metrics into a structured format with statistical analysis. 
@@ -380,7 +400,9 @@ def process_validation_metrics( >>> # result will contain statistics for each data source and variable """ # Group metrics by data source, prompt and variable - data_src2prompt2var2vals = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) + data_src2prompt2var2vals = defaultdict( + lambda: defaultdict(lambda: defaultdict(list)) + ) for sample_idx, data_source in enumerate(data_sources): prompt = sample_inputs[sample_idx] var2vals = data_src2prompt2var2vals[data_source][prompt] @@ -388,7 +410,9 @@ def process_validation_metrics( var2vals[var_name].append(var_vals[sample_idx]) # Calculate metrics for each group - data_src2prompt2var2metric = defaultdict(lambda: defaultdict(lambda: defaultdict(dict))) + data_src2prompt2var2metric = defaultdict( + lambda: defaultdict(lambda: defaultdict(dict)) + ) for data_source, prompt2var2vals in data_src2prompt2var2vals.items(): for prompt, var2vals in prompt2var2vals.items(): for var_name, var_vals in var2vals.items(): @@ -411,36 +435,63 @@ def process_validation_metrics( for n in ns: [(bon_mean, bon_std), (won_mean, won_std)] = bootstrap_metric( - data=var_vals, subset_size=n, reduce_fns=[np.max, np.min], seed=seed + data=var_vals, + subset_size=n, + reduce_fns=[np.max, np.min], + seed=seed, + ) + metric[f"best@{n}/mean"], metric[f"best@{n}/std"] = ( + bon_mean, + bon_std, + ) + metric[f"worst@{n}/mean"], metric[f"worst@{n}/std"] = ( + won_mean, + won_std, ) - metric[f"best@{n}/mean"], metric[f"best@{n}/std"] = bon_mean, bon_std - metric[f"worst@{n}/mean"], metric[f"worst@{n}/std"] = won_mean, won_std if var2vals.get("pred", None) is not None: vote_data = [ - {"val": val, "pred": pred} for val, pred in zip(var_vals, var2vals["pred"], strict=True) + {"val": val, "pred": pred} + for val, pred in zip( + var_vals, var2vals["pred"], strict=True + ) ] [(maj_n_mean, maj_n_std)] = bootstrap_metric( data=vote_data, subset_size=n, - reduce_fns=[partial(calc_maj_val, vote_key="pred", val_key="val")], + 
reduce_fns=[ + partial( + calc_maj_val, vote_key="pred", val_key="val" + ) + ], seed=seed, ) - metric[f"maj@{n}/mean"], metric[f"maj@{n}/std"] = maj_n_mean, maj_n_std + metric[f"maj@{n}/mean"], metric[f"maj@{n}/std"] = ( + maj_n_mean, + maj_n_std, + ) data_src2prompt2var2metric[data_source][prompt][var_name] = metric # Aggregate metrics across prompts - data_src2var2metric2prompt_vals = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) + data_src2var2metric2prompt_vals = defaultdict( + lambda: defaultdict(lambda: defaultdict(list)) + ) for data_source, prompt2var2metric in data_src2prompt2var2metric.items(): for prompt, var2metric in prompt2var2metric.items(): for var_name, metric in var2metric.items(): for metric_name, metric_val in metric.items(): - data_src2var2metric2prompt_vals[data_source][var_name][metric_name].append(metric_val) + data_src2var2metric2prompt_vals[data_source][var_name][ + metric_name + ].append(metric_val) - data_src2var2metric2val = defaultdict(lambda: defaultdict(lambda: defaultdict(float))) + data_src2var2metric2val = defaultdict( + lambda: defaultdict(lambda: defaultdict(float)) + ) for data_source, var2metric2prompt_vals in data_src2var2metric2prompt_vals.items(): for var_name, metric2prompt_vals in var2metric2prompt_vals.items(): for metric_name, prompt_vals in metric2prompt_vals.items(): - data_src2var2metric2val[data_source][var_name][metric_name] = np.mean(prompt_vals) + data_src2var2metric2val[data_source][var_name][metric_name] = np.mean( + prompt_vals + ) return data_src2var2metric2val diff --git a/Agent0/executor_train/verl/verl/trainer/ppo/ray_trainer.py b/Agent0/executor_train/verl/verl/trainer/ppo/ray_trainer.py index 5ba32ac..9427875 100644 --- a/Agent0/executor_train/verl/verl/trainer/ppo/ray_trainer.py +++ b/Agent0/executor_train/verl/verl/trainer/ppo/ray_trainer.py @@ -40,7 +40,11 @@ from verl.experimental.dataset.sampler import AbstractCurriculumSampler from verl.protocol import pad_dataproto_to_divisor, 
unpad_dataproto from verl.single_controller.base import Worker -from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup +from verl.single_controller.ray import ( + RayClassWithInitArgs, + RayResourcePool, + RayWorkerGroup, +) from verl.single_controller.ray.base import create_colocated_worker_cls from verl.trainer.config import AlgoConfig from verl.trainer.ppo import core_algos @@ -52,12 +56,18 @@ process_validation_metrics, ) from verl.trainer.ppo.reward import compute_reward, compute_reward_async -from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path, should_save_ckpt_esi +from verl.utils.checkpoint.checkpoint_manager import ( + find_latest_ckpt_path, + should_save_ckpt_esi, +) from verl.utils.debug import marked_timer from verl.utils.metric import ( reduce_metrics, ) -from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance +from verl.utils.seqlen_balancing import ( + get_seqlen_balanced_partitions, + log_seqlen_unbalance, +) from verl.utils.torch_functional import masked_mean from verl.utils.tracking import ValidationGenerationsLogger @@ -102,7 +112,10 @@ def create_resource_pool(self): # For Megatron backend, we recommend using max_colocate_count>1 # that can utilize different WorkerGroup for differnt models resource_pool = RayResourcePool( - process_on_nodes=process_on_nodes, use_gpu=True, max_colocate_count=1, name_prefix=resource_pool_name + process_on_nodes=process_on_nodes, + use_gpu=True, + max_colocate_count=1, + name_prefix=resource_pool_name, ) self.resource_pool_dict[resource_pool_name] = resource_pool @@ -114,20 +127,34 @@ def get_resource_pool(self, role: Role) -> RayResourcePool: def get_n_gpus(self) -> int: """Get the number of gpus in this cluster.""" - return sum([n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes]) + return sum( + [ + n_gpus + for process_on_nodes in self.resource_pool_spec.values() + for 
n_gpus in process_on_nodes + ] + ) def _check_resource_available(self): """Check if the resource pool can be satisfied in this ray cluster.""" node_available_resources = ray.state.available_resources_per_node() node_available_gpus = { - node: node_info.get("GPU", 0) if "GPU" in node_info else node_info.get("NPU", 0) + node: ( + node_info.get("GPU", 0) + if "GPU" in node_info + else node_info.get("NPU", 0) + ) for node, node_info in node_available_resources.items() } # check total required gpus can be satisfied total_available_gpus = sum(node_available_gpus.values()) total_required_gpus = sum( - [n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes] + [ + n_gpus + for process_on_nodes in self.resource_pool_spec.values() + for n_gpus in process_on_nodes + ] ) if total_available_gpus < total_required_gpus: raise ValueError( @@ -150,7 +177,9 @@ def _check_resource_available(self): ) -def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, kl_penalty="kl"): +def apply_kl_penalty( + data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, kl_penalty="kl" +): """Apply KL penalty to the token-level rewards. 
This function computes the KL divergence between the reference policy and current policy, @@ -188,7 +217,10 @@ def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, kl_ctrl.update(current_kl=current_kl, n_steps=batch_size) data.batch["token_level_rewards"] = token_level_rewards - metrics = {"actor/reward_kl_penalty": current_kl, "actor/reward_kl_penalty_coeff": beta} + metrics = { + "actor/reward_kl_penalty": current_kl, + "actor/reward_kl_penalty_coeff": beta, + } return data, metrics @@ -352,7 +384,9 @@ def __init__( assert self.hybrid_engine, "Currently, only support hybrid engine" if self.hybrid_engine: - assert Role.ActorRollout in role_worker_mapping, f"{role_worker_mapping.keys()=}" + assert ( + Role.ActorRollout in role_worker_mapping + ), f"{role_worker_mapping.keys()=}" self.role_worker_mapping = role_worker_mapping self.resource_pool_manager = resource_pool_manager @@ -368,7 +402,9 @@ def __init__( # define in-reward KL control # kl loss control currently not suppoorted if self.config.algorithm.use_kl_in_reward: - self.kl_ctrl_in_reward = core_algos.get_kl_controller(self.config.algorithm.kl_ctrl) + self.kl_ctrl_in_reward = core_algos.get_kl_controller( + self.config.algorithm.kl_ctrl + ) if self.config.algorithm.adv_estimator == AdvantageEstimator.GAE: self.use_critic = True @@ -399,20 +435,31 @@ def _validate_config(self): * config.actor_rollout_ref.actor.megatron.pipeline_model_parallel_size ) assert ( - n_gpus % (model_parallel_size * config.actor_rollout_ref.actor.megatron.context_parallel_size) == 0 + n_gpus + % ( + model_parallel_size + * config.actor_rollout_ref.actor.megatron.context_parallel_size + ) + == 0 ), ( f"n_gpus ({n_gpus}) must be divisible by model_parallel_size ({model_parallel_size}) times " f"context_parallel_size ({config.actor_rollout_ref.actor.megatron.context_parallel_size})" ) megatron_dp = n_gpus // ( - model_parallel_size * config.actor_rollout_ref.actor.megatron.context_parallel_size + 
model_parallel_size + * config.actor_rollout_ref.actor.megatron.context_parallel_size + ) + minimal_bsz = ( + megatron_dp + * config.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu ) - minimal_bsz = megatron_dp * config.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu else: minimal_bsz = n_gpus # 1. Check total batch size for data correctness - real_train_batch_size = config.data.train_batch_size * config.actor_rollout_ref.rollout.n + real_train_batch_size = ( + config.data.train_batch_size * config.actor_rollout_ref.rollout.n + ) assert real_train_batch_size % minimal_bsz == 0, ( f"real_train_batch_size ({real_train_batch_size}) must be divisible by minimal possible batch size " f"({minimal_bsz})" @@ -483,13 +530,17 @@ def check_mutually_exclusive(mbs, mbs_per_gpu, name: str): if self.use_critic and not config.critic.use_dynamic_bsz: # Check for critic micro-batch size conflicts check_mutually_exclusive( - config.critic.ppo_micro_batch_size, config.critic.ppo_micro_batch_size_per_gpu, "critic" + config.critic.ppo_micro_batch_size, + config.critic.ppo_micro_batch_size_per_gpu, + "critic", ) # Check for reward model micro-batch size conflicts if config.reward_model.enable and not config.reward_model.use_dynamic_bsz: check_mutually_exclusive( - config.reward_model.micro_batch_size, config.reward_model.micro_batch_size_per_gpu, "reward_model" + config.reward_model.micro_batch_size, + config.reward_model.micro_batch_size_per_gpu, + "reward_model", ) # Actor @@ -498,15 +549,23 @@ def check_mutually_exclusive(mbs, mbs_per_gpu, name: str): # ppo_mini_batch_size is divisible by ppo_micro_batch_size # ppo_micro_batch_size * sequence_parallel_size >= n_gpus if not config.actor_rollout_ref.actor.use_dynamic_bsz: - assert config.data.train_batch_size >= config.actor_rollout_ref.actor.ppo_mini_batch_size - sp_size = config.actor_rollout_ref.actor.get("ulysses_sequence_parallel_size", 1) + assert ( + config.data.train_batch_size + >= 
config.actor_rollout_ref.actor.ppo_mini_batch_size + ) + sp_size = config.actor_rollout_ref.actor.get( + "ulysses_sequence_parallel_size", 1 + ) if config.actor_rollout_ref.actor.ppo_micro_batch_size is not None: assert ( config.actor_rollout_ref.actor.ppo_mini_batch_size % config.actor_rollout_ref.actor.ppo_micro_batch_size == 0 ) - assert config.actor_rollout_ref.actor.ppo_micro_batch_size * sp_size >= n_gpus + assert ( + config.actor_rollout_ref.actor.ppo_micro_batch_size * sp_size + >= n_gpus + ) assert config.actor_rollout_ref.actor.loss_agg_mode in [ "token-mean", @@ -515,7 +574,10 @@ def check_mutually_exclusive(mbs, mbs_per_gpu, name: str): "seq-mean-token-sum-norm", ], f"Invalid loss_agg_mode: {config.actor_rollout_ref.actor.loss_agg_mode}" - if self.config.algorithm.use_kl_in_reward and config.actor_rollout_ref.actor.use_kl_loss: + if ( + self.config.algorithm.use_kl_in_reward + and config.actor_rollout_ref.actor.use_kl_loss + ): print("NOTICE: You have both enabled in-reward kl and kl loss.") # critic @@ -523,7 +585,11 @@ def check_mutually_exclusive(mbs, mbs_per_gpu, name: str): assert config.data.train_batch_size >= config.critic.ppo_mini_batch_size sp_size = config.critic.get("ulysses_sequence_parallel_size", 1) if config.critic.ppo_micro_batch_size is not None: - assert config.critic.ppo_mini_batch_size % config.critic.ppo_micro_batch_size == 0 + assert ( + config.critic.ppo_mini_batch_size + % config.critic.ppo_micro_batch_size + == 0 + ) assert config.critic.ppo_micro_batch_size * sp_size >= n_gpus # Check if use_remove_padding is enabled when using sequence parallelism for fsdp @@ -531,15 +597,15 @@ def check_mutually_exclusive(mbs, mbs_per_gpu, name: str): config.actor_rollout_ref.actor.get("ulysses_sequence_parallel_size", 1) > 1 or config.actor_rollout_ref.ref.get("ulysses_sequence_parallel_size", 1) > 1 ): - assert config.actor_rollout_ref.model.use_remove_padding, ( - "When using sequence parallelism for actor/ref policy, you must enable 
`use_remove_padding`." - ) + assert ( + config.actor_rollout_ref.model.use_remove_padding + ), "When using sequence parallelism for actor/ref policy, you must enable `use_remove_padding`." if self.use_critic and config.critic.strategy in {"fsdp", "fsdp2"}: if config.critic.get("ulysses_sequence_parallel_size", 1) > 1: - assert config.critic.model.use_remove_padding, ( - "When using sequence parallelism for critic, you must enable `use_remove_padding`." - ) + assert ( + config.critic.model.use_remove_padding + ), "When using sequence parallelism for critic, you must enable `use_remove_padding`." if config.data.get("val_batch_size", None) is not None: print( @@ -550,15 +616,16 @@ def check_mutually_exclusive(mbs, mbs_per_gpu, name: str): # check eval config if config.actor_rollout_ref.rollout.val_kwargs.do_sample: - assert config.actor_rollout_ref.rollout.temperature > 0, ( - "validation gen temperature should be greater than 0 when enabling do_sample" - ) + assert ( + config.actor_rollout_ref.rollout.temperature > 0 + ), "validation gen temperature should be greater than 0 when enabling do_sample" # check multi_turn with tool config if config.actor_rollout_ref.rollout.multi_turn.enable: assert ( config.actor_rollout_ref.rollout.multi_turn.tool_config_path is not None - or config.actor_rollout_ref.rollout.multi_turn.interaction_config_path is not None + or config.actor_rollout_ref.rollout.multi_turn.interaction_config_path + is not None ), ( "tool_config_path or interaction_config_path must be set when enabling multi_turn with tool, " "due to no role-playing support" @@ -566,7 +633,9 @@ def check_mutually_exclusive(mbs, mbs_per_gpu, name: str): print("[validate_config] All configuration checks passed successfully!") - def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler: Optional[Sampler]): + def _create_dataloader( + self, train_dataset, val_dataset, collate_fn, train_sampler: Optional[Sampler] + ): """ Creates the train and validation 
dataloaders. """ @@ -575,11 +644,17 @@ def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampl if train_dataset is None: train_dataset = create_rl_dataset( - self.config.data.train_files, self.config.data, self.tokenizer, self.processor + self.config.data.train_files, + self.config.data, + self.tokenizer, + self.processor, ) if val_dataset is None: val_dataset = create_rl_dataset( - self.config.data.val_files, self.config.data, self.tokenizer, self.processor + self.config.data.val_files, + self.config.data, + self.tokenizer, + self.processor, ) self.train_dataset, self.val_dataset = train_dataset, val_dataset @@ -594,7 +669,9 @@ def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampl self.train_dataloader = StatefulDataLoader( dataset=self.train_dataset, - batch_size=self.config.data.get("gen_batch_size", self.config.data.train_batch_size), + batch_size=self.config.data.get( + "gen_batch_size", self.config.data.train_batch_size + ), num_workers=num_workers, drop_last=True, collate_fn=collate_fn, @@ -622,7 +699,9 @@ def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampl f"{len(self.val_dataloader)}" ) - total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs + total_training_steps = ( + len(self.train_dataloader) * self.config.trainer.total_epochs + ) if self.config.trainer.total_training_steps is not None: total_training_steps = self.config.trainer.total_training_steps @@ -634,13 +713,19 @@ def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampl OmegaConf.set_struct(self.config, True) with open_dict(self.config): if OmegaConf.select(self.config, "actor_rollout_ref.actor.optim"): - self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps + self.config.actor_rollout_ref.actor.optim.total_training_steps = ( + total_training_steps + ) if OmegaConf.select(self.config, "critic.optim"): 
self.config.critic.optim.total_training_steps = total_training_steps except Exception as e: - print(f"Warning: Could not set total_training_steps in config. Structure missing? Error: {e}") + print( + f"Warning: Could not set total_training_steps in config. Structure missing? Error: {e}" + ) - def _dump_generations(self, inputs, outputs, scores, reward_extra_infos_dict, dump_path): + def _dump_generations( + self, inputs, outputs, scores, reward_extra_infos_dict, dump_path + ): """Dump rollout/validation samples as JSONL.""" os.makedirs(dump_path, exist_ok=True) filename = os.path.join(dump_path, f"{self.global_steps}.jsonl") @@ -689,7 +774,9 @@ def _maybe_log_val_generations(self, inputs, outputs, scores): samples = samples[:generations_to_log] # Log to each configured logger - self.validation_generations_logger.log(self.config.trainer.logger, samples, self.global_steps) + self.validation_generations_logger.log( + self.config.trainer.logger, samples, self.global_steps + ) def _validate(self): data_source_lst = [] @@ -706,17 +793,24 @@ def _validate(self): # repeat test batch test_batch = test_batch.repeat( - repeat_times=self.config.actor_rollout_ref.rollout.val_kwargs.n, interleave=True + repeat_times=self.config.actor_rollout_ref.rollout.val_kwargs.n, + interleave=True, ) # we only do validation on rule-based rm - if self.config.reward_model.enable and test_batch[0].non_tensor_batch["reward_model"]["style"] == "model": + if ( + self.config.reward_model.enable + and test_batch[0].non_tensor_batch["reward_model"]["style"] == "model" + ): return {} # Store original inputs input_ids = test_batch.batch["input_ids"] # TODO: Can we keep special tokens except for padding tokens? 
- input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids] + input_texts = [ + self.tokenizer.decode(ids, skip_special_tokens=True) + for ids in input_ids + ] sample_inputs.extend(input_texts) batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"] @@ -751,20 +845,31 @@ def _validate(self): if not self.async_rollout_mode else self.config.actor_rollout_ref.rollout.agent.num_workers ) - test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, size_divisor) + test_gen_batch_padded, pad_size = pad_dataproto_to_divisor( + test_gen_batch, size_divisor + ) if not self.async_rollout_mode: - test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences(test_gen_batch_padded) + test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences( + test_gen_batch_padded + ) else: - test_output_gen_batch_padded = self.async_rollout_manager.generate_sequences(test_gen_batch_padded) + test_output_gen_batch_padded = ( + self.async_rollout_manager.generate_sequences(test_gen_batch_padded) + ) # unpad - test_output_gen_batch = unpad_dataproto(test_output_gen_batch_padded, pad_size=pad_size) + test_output_gen_batch = unpad_dataproto( + test_output_gen_batch_padded, pad_size=pad_size + ) print("validation generation end") # Store generated outputs output_ids = test_output_gen_batch.batch["responses"] - output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids] + output_texts = [ + self.tokenizer.decode(ids, skip_special_tokens=True) + for ids in output_ids + ] sample_outputs.extend(output_texts) test_batch = test_batch.union(test_output_gen_batch) @@ -777,19 +882,29 @@ def _validate(self): sample_scores.extend(scores) reward_extra_infos_dict["reward"].extend(scores) - print(f"len reward_extra_infos_dict['reward']: {len(reward_extra_infos_dict['reward'])}") + print( + f"len reward_extra_infos_dict['reward']: {len(reward_extra_infos_dict['reward'])}" + ) if "reward_extra_info" 
in result: for key, lst in result["reward_extra_info"].items(): reward_extra_infos_dict[key].extend(lst) - print(f"len reward_extra_infos_dict['{key}']: {len(reward_extra_infos_dict[key])}") + print( + f"len reward_extra_infos_dict['{key}']: {len(reward_extra_infos_dict[key])}" + ) # collect num_turns of each prompt if "__num_turns__" in test_batch.non_tensor_batch: sample_turns.append(test_batch.non_tensor_batch["__num_turns__"]) - data_source_lst.append(test_batch.non_tensor_batch.get("data_source", ["unknown"] * reward_tensor.shape[0])) + data_source_lst.append( + test_batch.non_tensor_batch.get( + "data_source", ["unknown"] * reward_tensor.shape[0] + ) + ) - self._maybe_log_val_generations(inputs=sample_inputs, outputs=sample_outputs, scores=sample_scores) + self._maybe_log_val_generations( + inputs=sample_inputs, outputs=sample_outputs, scores=sample_scores + ) # dump generations val_data_dir = self.config.trainer.get("validation_data_dir", None) @@ -803,20 +918,32 @@ def _validate(self): ) for key_info, lst in reward_extra_infos_dict.items(): - assert len(lst) == 0 or len(lst) == len(sample_scores), f"{key_info}: {len(lst)=}, {len(sample_scores)=}" + assert len(lst) == 0 or len(lst) == len( + sample_scores + ), f"{key_info}: {len(lst)=}, {len(sample_scores)=}" data_sources = np.concatenate(data_source_lst, axis=0) - data_src2var2metric2val = process_validation_metrics(data_sources, sample_inputs, reward_extra_infos_dict) + data_src2var2metric2val = process_validation_metrics( + data_sources, sample_inputs, reward_extra_infos_dict + ) metric_dict = {} for data_source, var2metric2val in data_src2var2metric2val.items(): core_var = "acc" if "acc" in var2metric2val else "reward" for var_name, metric2val in var2metric2val.items(): - n_max = max([int(name.split("@")[-1].split("/")[0]) for name in metric2val.keys()]) + n_max = max( + [ + int(name.split("@")[-1].split("/")[0]) + for name in metric2val.keys() + ] + ) for metric_name, metric_val in metric2val.items(): 
if ( (var_name == core_var) - and any(metric_name.startswith(pfx) for pfx in ["mean", "maj", "best"]) + and any( + metric_name.startswith(pfx) + for pfx in ["mean", "maj", "best"] + ) and (f"@{n_max}" in metric_name) ): metric_sec = "val-core" @@ -842,25 +969,33 @@ def init_workers(self): """ self.resource_pool_manager.create_resource_pool() - self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()} + self.resource_pool_to_cls = { + pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values() + } # create actor and rollout if self.hybrid_engine: - resource_pool = self.resource_pool_manager.get_resource_pool(Role.ActorRollout) + resource_pool = self.resource_pool_manager.get_resource_pool( + Role.ActorRollout + ) actor_rollout_cls = RayClassWithInitArgs( cls=self.role_worker_mapping[Role.ActorRollout], config=self.config.actor_rollout_ref, role="actor_rollout", profile_option=self.config.trainer.npu_profile.options, ) - self.resource_pool_to_cls[resource_pool]["actor_rollout"] = actor_rollout_cls + self.resource_pool_to_cls[resource_pool][ + "actor_rollout" + ] = actor_rollout_cls else: raise NotImplementedError # create critic if self.use_critic: resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic) - critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=self.config.critic) + critic_cls = RayClassWithInitArgs( + cls=self.role_worker_mapping[Role.Critic], config=self.config.critic + ) self.resource_pool_to_cls[resource_pool]["critic"] = critic_cls # create reference policy if needed @@ -877,8 +1012,13 @@ def init_workers(self): # create a reward model if reward_fn is None if self.use_rm: # we create a RM here - resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel) - rm_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model) + resource_pool = 
self.resource_pool_manager.get_resource_pool( + Role.RewardModel + ) + rm_cls = RayClassWithInitArgs( + self.role_worker_mapping[Role.RewardModel], + config=self.config.reward_model, + ) self.resource_pool_to_cls[resource_pool]["rm"] = rm_cls # initialize WorkerGroup @@ -888,13 +1028,21 @@ def init_workers(self): # See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information. all_wg = {} wg_kwargs = {} # Setting up kwargs for RayWorkerGroup - if OmegaConf.select(self.config.trainer, "ray_wait_register_center_timeout") is not None: - wg_kwargs["ray_wait_register_center_timeout"] = self.config.trainer.ray_wait_register_center_timeout + if ( + OmegaConf.select(self.config.trainer, "ray_wait_register_center_timeout") + is not None + ): + wg_kwargs["ray_wait_register_center_timeout"] = ( + self.config.trainer.ray_wait_register_center_timeout + ) if OmegaConf.select(self.config.trainer, "profile_steps") is not None: - wg_kwargs["profile_steps"] = OmegaConf.select(self.config.trainer, "profile_steps") - assert OmegaConf.select(self.config.trainer, "worker_nsight_options") is not None, ( - "worker_nsight_options must be set when profile_steps is set" + wg_kwargs["profile_steps"] = OmegaConf.select( + self.config.trainer, "profile_steps" ) + assert ( + OmegaConf.select(self.config.trainer, "worker_nsight_options") + is not None + ), "worker_nsight_options must be set when profile_steps is set" wg_kwargs["worker_nsight_options"] = OmegaConf.to_container( OmegaConf.select(self.config.trainer, "worker_nsight_options") ) @@ -951,24 +1099,37 @@ def _save_checkpoint(self): actor_remote_path = ( None if self.config.trainer.default_hdfs_dir is None - else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "actor") + else os.path.join( + self.config.trainer.default_hdfs_dir, + f"global_step_{self.global_steps}", + "actor", + ) ) - remove_previous_ckpt_in_save = 
self.config.trainer.get("remove_previous_ckpt_in_save", False) + remove_previous_ckpt_in_save = self.config.trainer.get( + "remove_previous_ckpt_in_save", False + ) if remove_previous_ckpt_in_save: print( "Warning: remove_previous_ckpt_in_save is deprecated," + " set max_actor_ckpt_to_keep=1 and max_critic_ckpt_to_keep=1 instead" ) max_actor_ckpt_to_keep = ( - self.config.trainer.get("max_actor_ckpt_to_keep", None) if not remove_previous_ckpt_in_save else 1 + self.config.trainer.get("max_actor_ckpt_to_keep", None) + if not remove_previous_ckpt_in_save + else 1 ) max_critic_ckpt_to_keep = ( - self.config.trainer.get("max_critic_ckpt_to_keep", None) if not remove_previous_ckpt_in_save else 1 + self.config.trainer.get("max_critic_ckpt_to_keep", None) + if not remove_previous_ckpt_in_save + else 1 ) self.actor_rollout_wg.save_checkpoint( - actor_local_path, actor_remote_path, self.global_steps, max_ckpt_to_keep=max_actor_ckpt_to_keep + actor_local_path, + actor_remote_path, + self.global_steps, + max_ckpt_to_keep=max_actor_ckpt_to_keep, ) if self.use_critic: @@ -976,10 +1137,17 @@ def _save_checkpoint(self): critic_remote_path = ( None if self.config.trainer.default_hdfs_dir is None - else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "critic") + else os.path.join( + self.config.trainer.default_hdfs_dir, + f"global_step_{self.global_steps}", + "critic", + ) ) self.critic_wg.save_checkpoint( - critic_local_path, critic_remote_path, self.global_steps, max_ckpt_to_keep=max_critic_ckpt_to_keep + critic_local_path, + critic_remote_path, + self.global_steps, + max_ckpt_to_keep=max_critic_ckpt_to_keep, ) # save dataloader @@ -1003,11 +1171,15 @@ def _load_checkpoint(self): if self.config.trainer.default_hdfs_dir is not None: raise NotImplementedError("load from hdfs is not implemented yet") else: - checkpoint_folder = self.config.trainer.default_local_dir # TODO: check path + checkpoint_folder = ( + 
self.config.trainer.default_local_dir + ) # TODO: check path if not os.path.isabs(checkpoint_folder): working_dir = os.getcwd() checkpoint_folder = os.path.join(working_dir, checkpoint_folder) - global_step_folder = find_latest_ckpt_path(checkpoint_folder) # None if no latest + global_step_folder = find_latest_ckpt_path( + checkpoint_folder + ) # None if no latest # find global_step_folder if self.config.trainer.resume_mode == "auto": @@ -1016,10 +1188,12 @@ def _load_checkpoint(self): return 0 else: if self.config.trainer.resume_mode == "resume_path": - assert isinstance(self.config.trainer.resume_from_path, str), "resume ckpt must be str type" - assert "global_step_" in self.config.trainer.resume_from_path, ( - "resume ckpt must specify the global_steps" - ) + assert isinstance( + self.config.trainer.resume_from_path, str + ), "resume ckpt must be str type" + assert ( + "global_step_" in self.config.trainer.resume_from_path + ), "resume ckpt must specify the global_steps" global_step_folder = self.config.trainer.resume_from_path if not os.path.isabs(global_step_folder): working_dir = os.getcwd() @@ -1035,37 +1209,49 @@ def _load_checkpoint(self): critic_path = os.path.join(global_step_folder, "critic") # load actor self.actor_rollout_wg.load_checkpoint( - actor_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load + actor_path, + del_local_after_load=self.config.trainer.del_local_ckpt_after_load, ) # load critic if self.use_critic: self.critic_wg.load_checkpoint( - critic_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load + critic_path, + del_local_after_load=self.config.trainer.del_local_ckpt_after_load, ) # load dataloader, # TODO: from remote not implemented yet dataloader_local_path = os.path.join(global_step_folder, "data.pt") if os.path.exists(dataloader_local_path): - dataloader_state_dict = torch.load(dataloader_local_path, weights_only=False) + dataloader_state_dict = torch.load( + dataloader_local_path, 
weights_only=False + ) self.train_dataloader.load_state_dict(dataloader_state_dict) else: - print(f"Warning: No dataloader state found at {dataloader_local_path}, will start from scratch") + print( + f"Warning: No dataloader state found at {dataloader_local_path}, will start from scratch" + ) def _balance_batch(self, batch: DataProto, metrics, logging_prefix="global_seqlen"): """Reorder the data on single controller such that each dp rank gets similar total tokens""" attention_mask = batch.batch["attention_mask"] batch_size = attention_mask.shape[0] - global_seqlen_lst = batch.batch["attention_mask"].view(batch_size, -1).sum(-1).tolist() # (train_batch_size,) + global_seqlen_lst = ( + batch.batch["attention_mask"].view(batch_size, -1).sum(-1).tolist() + ) # (train_batch_size,) world_size = self.actor_rollout_wg.world_size global_partition_lst = get_seqlen_balanced_partitions( global_seqlen_lst, k_partitions=world_size, equal_size=True ) # reorder based on index. The data will be automatically equally partitioned by dispatch function - global_idx = torch.tensor([j for partition in global_partition_lst for j in partition]) + global_idx = torch.tensor( + [j for partition in global_partition_lst for j in partition] + ) batch.reorder(global_idx) global_balance_stats = log_seqlen_unbalance( - seqlen_list=global_seqlen_lst, partitions=global_partition_lst, prefix=logging_prefix + seqlen_list=global_seqlen_lst, + partitions=global_partition_lst, + prefix=logging_prefix, ) metrics.update(global_balance_stats) @@ -1094,7 +1280,9 @@ def fit(self): # perform validation before training # currently, we only support validation using the reward_function. 
- if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True): + if self.val_reward_fn is not None and self.config.trainer.get( + "val_before_train", True + ): val_metrics = self._validate() assert val_metrics, f"{val_metrics=}" pprint(f"Initial validation metrics: {val_metrics}") @@ -1103,7 +1291,11 @@ def fit(self): return # add tqdm - progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress") + progress_bar = tqdm( + total=self.total_training_steps, + initial=self.global_steps, + desc="Training Progress", + ) # we start from step 1 self.global_steps += 1 @@ -1122,7 +1314,9 @@ def fit(self): ) with marked_timer("start_profile", timing_raw): if do_profile: - self.actor_rollout_wg.start_profile(role="e2e", profile_step=self.global_steps) + self.actor_rollout_wg.start_profile( + role="e2e", profile_step=self.global_steps + ) if self.use_reference_policy: self.ref_policy_wg.start_profile() if self.use_critic: @@ -1155,7 +1349,10 @@ def fit(self): # pass global_steps to trace gen_batch.meta_info["global_steps"] = self.global_steps - gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) + gen_batch = gen_batch.repeat( + repeat_times=self.config.actor_rollout_ref.rollout.n, + interleave=True, + ) is_last_step = self.global_steps >= self.total_training_steps @@ -1163,9 +1360,13 @@ def fit(self): # generate a batch with marked_timer("gen", timing_raw, color="red"): if not self.async_rollout_mode: - gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) + gen_batch_output = self.actor_rollout_wg.generate_sequences( + gen_batch + ) else: - gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch) + gen_batch_output = ( + self.async_rollout_manager.generate_sequences(gen_batch) + ) timing_raw.update(gen_batch_output.meta_info["timing"]) gen_batch_output.meta_info.pop("timing", None) @@ -1173,7 +1374,11 @@ def fit(self): 
with marked_timer("gen_max", timing_raw, color="purple"): gen_baseline_batch = deepcopy(gen_batch) gen_baseline_batch.meta_info["do_sample"] = False - gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch) + gen_baseline_output = ( + self.actor_rollout_wg.generate_sequences( + gen_baseline_batch + ) + ) batch = batch.union(gen_baseline_output) reward_baseline_tensor = self.reward_fn(batch) @@ -1186,10 +1391,14 @@ def fit(self): del gen_baseline_batch, gen_baseline_output batch.non_tensor_batch["uid"] = np.array( - [str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object + [str(uuid.uuid4()) for _ in range(len(batch.batch))], + dtype=object, ) # repeat to align with repeated responses in rollout - batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) + batch = batch.repeat( + repeat_times=self.config.actor_rollout_ref.rollout.n, + interleave=True, + ) batch = batch.union(gen_batch_output) if "response_mask" not in batch.batch.keys(): @@ -1203,7 +1412,9 @@ def fit(self): self._balance_batch(batch, metrics=metrics) # compute global_valid tokens - batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist() + batch.meta_info["global_token_num"] = torch.sum( + batch.batch["attention_mask"], dim=-1 + ).tolist() with marked_timer("reward", timing_raw, color="yellow"): # compute reward model score @@ -1212,18 +1423,30 @@ def fit(self): batch = batch.union(reward_tensor) if self.config.reward_model.launch_reward_fn_async: - future_reward = compute_reward_async.remote(batch, self.config, self.tokenizer) + future_reward = compute_reward_async.remote( + batch, self.config, self.tokenizer + ) else: - reward_tensor, reward_extra_infos_dict = compute_reward(batch, self.reward_fn) + reward_tensor, reward_extra_infos_dict = compute_reward( + batch, self.reward_fn + ) # recompute old_log_probs with marked_timer("old_log_prob", timing_raw, color="blue"): old_log_prob = 
self.actor_rollout_wg.compute_log_prob(batch) entropys = old_log_prob.batch["entropys"] response_masks = batch.batch["response_mask"] - loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode - entropy_agg = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode) - old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()} + loss_agg_mode = ( + self.config.actor_rollout_ref.actor.loss_agg_mode + ) + entropy_agg = agg_loss( + loss_mat=entropys, + loss_mask=response_masks, + loss_agg_mode=loss_agg_mode, + ) + old_log_prob_metrics = { + "actor/entropy": entropy_agg.detach().item() + } metrics.update(old_log_prob_metrics) old_log_prob.batch.pop("entropys") batch = batch.union(old_log_prob) @@ -1240,7 +1463,9 @@ def fit(self): rollout_probs = torch.exp(rollout_old_log_probs) actor_probs = torch.exp(actor_old_log_probs) rollout_probs_diff = torch.abs(rollout_probs - actor_probs) - rollout_probs_diff = torch.masked_select(rollout_probs_diff, response_mask.bool()) + rollout_probs_diff = torch.masked_select( + rollout_probs_diff, response_mask.bool() + ) rollout_probs_diff_max = torch.max(rollout_probs_diff) rollout_probs_diff_mean = torch.mean(rollout_probs_diff) rollout_probs_diff_std = torch.std(rollout_probs_diff) @@ -1256,9 +1481,13 @@ def fit(self): # compute reference log_prob with marked_timer("ref", timing_raw, color="olive"): if not self.ref_in_actor: - ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) + ref_log_prob = self.ref_policy_wg.compute_ref_log_prob( + batch + ) else: - ref_log_prob = self.actor_rollout_wg.compute_ref_log_prob(batch) + ref_log_prob = ( + self.actor_rollout_wg.compute_ref_log_prob(batch) + ) batch = batch.union(ref_log_prob) # compute values @@ -1271,20 +1500,31 @@ def fit(self): # we combine with rule-based rm reward_extra_infos_dict: dict[str, list] if self.config.reward_model.launch_reward_fn_async: - reward_tensor, reward_extra_infos_dict = ray.get(future_reward) + 
reward_tensor, reward_extra_infos_dict = ray.get( + future_reward + ) batch.batch["token_level_scores"] = reward_tensor if reward_extra_infos_dict: - batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()}) + batch.non_tensor_batch.update( + { + k: np.array(v) + for k, v in reward_extra_infos_dict.items() + } + ) # compute rewards. apply_kl_penalty if available if self.config.algorithm.use_kl_in_reward: batch, kl_metrics = apply_kl_penalty( - batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty + batch, + kl_ctrl=self.kl_ctrl_in_reward, + kl_penalty=self.config.algorithm.kl_penalty, ) metrics.update(kl_metrics) else: - batch.batch["token_level_rewards"] = batch.batch["token_level_scores"] + batch.batch["token_level_rewards"] = batch.batch[ + "token_level_scores" + ] # compute advantages, executed on the driver process @@ -1306,26 +1546,40 @@ def fit(self): if self.use_critic: with marked_timer("update_critic", timing_raw, color="pink"): critic_output = self.critic_wg.update_critic(batch) - critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"]) + critic_output_metrics = reduce_metrics( + critic_output.meta_info["metrics"] + ) metrics.update(critic_output_metrics) # implement critic warmup if self.config.trainer.critic_warmup <= self.global_steps: # update actor with marked_timer("update_actor", timing_raw, color="red"): - batch.meta_info["multi_turn"] = self.config.actor_rollout_ref.rollout.multi_turn.enable + batch.meta_info["multi_turn"] = ( + self.config.actor_rollout_ref.rollout.multi_turn.enable + ) actor_output = self.actor_rollout_wg.update_actor(batch) - actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) + actor_output_metrics = reduce_metrics( + actor_output.meta_info["metrics"] + ) metrics.update(actor_output_metrics) # Log rollout generations if enabled rollout_data_dir = self.config.trainer.get("rollout_data_dir", None) if rollout_data_dir: - with 
marked_timer("dump_rollout_generations", timing_raw, color="green"): + with marked_timer( + "dump_rollout_generations", timing_raw, color="green" + ): print(batch.batch.keys()) - inputs = self.tokenizer.batch_decode(batch.batch["prompts"], skip_special_tokens=True) - outputs = self.tokenizer.batch_decode(batch.batch["responses"], skip_special_tokens=True) - scores = batch.batch["token_level_scores"].sum(-1).cpu().tolist() + inputs = self.tokenizer.batch_decode( + batch.batch["prompts"], skip_special_tokens=True + ) + outputs = self.tokenizer.batch_decode( + batch.batch["responses"], skip_special_tokens=True + ) + scores = ( + batch.batch["token_level_scores"].sum(-1).cpu().tolist() + ) self._dump_generations( inputs=inputs, outputs=outputs, @@ -1338,7 +1592,10 @@ def fit(self): if ( self.val_reward_fn is not None and self.config.trainer.test_freq > 0 - and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0) + and ( + is_last_step + or self.global_steps % self.config.trainer.test_freq == 0 + ) ): with marked_timer("testing", timing_raw, color="green"): val_metrics: dict = self._validate() @@ -1364,7 +1621,9 @@ def fit(self): or esi_close_to_expiration ): if esi_close_to_expiration: - print("Force saving checkpoint: ESI instance expiration approaching.") + print( + "Force saving checkpoint: ESI instance expiration approaching." 
+ ) with marked_timer("save_checkpoint", timing_raw, color="green"): self._save_checkpoint() @@ -1389,11 +1648,19 @@ def fit(self): } ) # collect metrics - metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic)) - metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) + metrics.update( + compute_data_metrics(batch=batch, use_critic=self.use_critic) + ) + metrics.update( + compute_timing_metrics(batch=batch, timing_raw=timing_raw) + ) # TODO: implement actual tflpo and theoretical tflpo n_gpus = self.resource_pool_manager.get_n_gpus() - metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus)) + metrics.update( + compute_throughout_metrics( + batch=batch, timing_raw=timing_raw, n_gpus=n_gpus + ) + ) # this is experimental and may be changed/removed in the future in favor of a general-purpose one if isinstance(self.train_dataloader.sampler, AbstractCurriculumSampler): diff --git a/Agent0/executor_train/verl/verl/trainer/ppo/reward.py b/Agent0/executor_train/verl/verl/trainer/ppo/reward.py index 143b631..41ff7b7 100644 --- a/Agent0/executor_train/verl/verl/trainer/ppo/reward.py +++ b/Agent0/executor_train/verl/verl/trainer/ppo/reward.py @@ -71,7 +71,9 @@ def get_custom_reward_fn(config): function_name = reward_fn_config.get("name") if not hasattr(module, function_name): - raise AttributeError(f"Reward function '{function_name}' not found in '{file_path}'.") + raise AttributeError( + f"Reward function '{function_name}' not found in '{file_path}'." 
+ ) print(f"using customized reward function '{function_name}' from '{file_path}'") raw_fn = getattr(module, function_name) @@ -118,7 +120,9 @@ def load_reward_manager(config, tokenizer, num_examine, **reward_kwargs): if sandbox_url: sandbox_manager = multiprocessing.Manager() # Create a semaphore to control concurrent access to the sandbox - _concurrent_semaphore = sandbox_manager.Semaphore(sandbox_config.get("max_concurrent", 64)) + _concurrent_semaphore = sandbox_manager.Semaphore( + sandbox_config.get("max_concurrent", 64) + ) final_compute_score = partial( default_compute_score, sandbox_fusion_url=sandbox_url, @@ -165,5 +169,7 @@ def compute_reward_async(data: DataProto, config, tokenizer): Load the reward manager and compute the reward for a batch of data. This is meant to be run in a separate Ray worker. """ - reward_fn = load_reward_manager(config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {})) + reward_fn = load_reward_manager( + config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {}) + ) return compute_reward(data, reward_fn) diff --git a/Agent0/executor_train/verl/verl/utils/__init__.py b/Agent0/executor_train/verl/verl/utils/__init__.py index 0345849..fc9d632 100644 --- a/Agent0/executor_train/verl/verl/utils/__init__.py +++ b/Agent0/executor_train/verl/verl/utils/__init__.py @@ -16,4 +16,8 @@ from .config import omega_conf_to_dataclass from .tokenizer import hf_processor, hf_tokenizer -__all__ = tokenizer.__all__ + config.__all__ + ["hf_processor", "hf_tokenizer", "omega_conf_to_dataclass"] +__all__ = ( + tokenizer.__all__ + + config.__all__ + + ["hf_processor", "hf_tokenizer", "omega_conf_to_dataclass"] +) diff --git a/Agent0/executor_train/verl/verl/utils/activation_offload.py b/Agent0/executor_train/verl/verl/utils/activation_offload.py index 73e2e83..3db774f 100644 --- a/Agent0/executor_train/verl/verl/utils/activation_offload.py +++ b/Agent0/executor_train/verl/verl/utils/activation_offload.py @@ 
-72,18 +72,24 @@ def __init__( def __enter__(self): self.inside_context = True - torch._C._autograd._push_saved_tensors_default_hooks(self.on_save_for_backward, self.on_get_saved_tensor) + torch._C._autograd._push_saved_tensors_default_hooks( + self.on_save_for_backward, self.on_get_saved_tensor + ) def __exit__(self, *args: Any): self.inside_context = False torch._C._autograd._pop_saved_tensors_default_hooks() def on_save_for_backward(self, tensor: torch.Tensor) -> Any: - retrieve_identifier = self.offload_handler.tensor_push(tensor, **self.handler_extra_kwargs) + retrieve_identifier = self.offload_handler.tensor_push( + tensor, **self.handler_extra_kwargs + ) return retrieve_identifier def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: - tensor = self.offload_handler.tensor_pop(saved_state, **self.handler_extra_kwargs) + tensor = self.offload_handler.tensor_pop( + saved_state, **self.handler_extra_kwargs + ) return tensor @@ -140,7 +146,9 @@ class SynchronizedGroupOffloadHandler(OffloadHandler): as the computation kernels, thus the copying will block computation. 
""" - def __init__(self, num_offload_group, tensor_need_offloading_checker=(lambda _: True)) -> None: + def __init__( + self, num_offload_group, tensor_need_offloading_checker=(lambda _: True) + ) -> None: super().__init__() self.num_offload_group = num_offload_group @@ -198,7 +206,10 @@ def tensor_push(self, tensor: torch.Tensor, **kwargs): tensor_tag = (self.current_group, self.tensor_count_current_group) self.tensor_count_current_group += 1 assert tensor_tag not in self.tensor_tag_to_state - if self.current_group < self.num_offload_group and self.tensor_need_offloading_checker(tensor): + if ( + self.current_group < self.num_offload_group + and self.tensor_need_offloading_checker(tensor) + ): state = SynchronizedGroupOffloadHandler.offload(tensor) self.tensor_tag_to_state[tensor_tag] = state else: @@ -249,7 +260,9 @@ def __init__( # for optimal CPU/GPU interconnect usage constant = 0 for i in range(self.num_offload_group): - self.layer_window_map[i] = ((self.num_layers // self.num_offload_group) * (i + 1)) - 1 + self.layer_window_map[i] = ( + (self.num_layers // self.num_offload_group) * (i + 1) + ) - 1 if i < (self.num_layers % self.num_offload_group): self.layer_window_map[i] += i + 1 constant = i + 1 @@ -263,7 +276,8 @@ def __init__( def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any: torch_stray_tensor = isinstance( tensor, - torch._subclasses.fake_tensor.FakeTensor | torch._subclasses.functional_tensor.FunctionalTensor, + torch._subclasses.fake_tensor.FakeTensor + | torch._subclasses.functional_tensor.FunctionalTensor, ) need_offload = not torch_stray_tensor need_offload = need_offload and self.tensor_need_offloading_checker(tensor) @@ -396,7 +410,9 @@ def on_group_commit_backward(self): def get_activation_offload_context( - num_layers: int = 1, model_layers: int = 1, tensor_need_offloading_checker=(lambda t: True) + num_layers: int = 1, + model_layers: int = 1, + tensor_need_offloading_checker=(lambda t: True), ): cpu_offload_handler = 
AsyncDoubleBufferGroupOffloadHandler( num_offload_group=num_layers, @@ -444,7 +460,9 @@ def _pack_kwargs(self, *args, **kwargs): return tuple(flat_args), tuple(kwarg_keys) def _unpack_kwargs(self, flat_args, kwarg_keys): - assert len(kwarg_keys) <= len(flat_args), f"too many keys {len(kwarg_keys)} vs. {len(flat_args)}" + assert len(kwarg_keys) <= len( + flat_args + ), f"too many keys {len(kwarg_keys)} vs. {len(flat_args)}" if len(kwarg_keys) == 0: return flat_args, {} args = flat_args[: -len(kwarg_keys)] @@ -518,7 +536,9 @@ def enable_activation_offloading(model, strategy, enable_ckpt=False): """ - assert strategy == "fsdp" or strategy == "fsdp2", "activation offloading only supports fsdp strategy" + assert ( + strategy == "fsdp" or strategy == "fsdp2" + ), "activation offloading only supports fsdp strategy" layers = [] def get_layers(module): @@ -536,11 +556,15 @@ def get_layers(module): get_layers(model) if len(layers) < 3: - logger.warning(f"Find only {len(layers)} fsdp layers, not neccessary to enable async activation offloading") + logger.warning( + f"Find only {len(layers)} fsdp layers, not neccessary to enable async activation offloading" + ) return tensor_filter = FSDPParameterFilter() - context, sync_func = get_activation_offload_context(len(layers) - 1, len(layers), tensor_filter) + context, sync_func = get_activation_offload_context( + len(layers) - 1, len(layers), tensor_filter + ) if enable_ckpt: # The implementation of activation checkpointing in transformers library is incompatible with # activation offloading, diff --git a/Agent0/executor_train/verl/verl/utils/checkpoint/checkpoint_manager.py b/Agent0/executor_train/verl/verl/utils/checkpoint/checkpoint_manager.py index ff861ab..8fb3a31 100644 --- a/Agent0/executor_train/verl/verl/utils/checkpoint/checkpoint_manager.py +++ b/Agent0/executor_train/verl/verl/utils/checkpoint/checkpoint_manager.py @@ -49,8 +49,12 @@ def __init__( checkpoint_config: DictConfig = None, ): self.checkpoint_config = 
checkpoint_config - checkpoint_load_contents = checkpoint_config.get("load_contents", None) if checkpoint_config else None - checkpoint_save_contents = checkpoint_config.get("save_contents", None) if checkpoint_config else None + checkpoint_load_contents = ( + checkpoint_config.get("load_contents", None) if checkpoint_config else None + ) + checkpoint_save_contents = ( + checkpoint_config.get("save_contents", None) if checkpoint_config else None + ) if checkpoint_load_contents is None: checkpoint_load_contents = ["model", "optimizer", "extra"] if checkpoint_save_contents is None: @@ -118,18 +122,28 @@ def should_load_extra(self) -> bool: """ return "extra" in self.checkpoint_load_contents - def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_after_load: bool = False): + def load_checkpoint( + self, local_path: str, hdfs_path: str = None, del_local_after_load: bool = False + ): raise NotImplementedError def save_checkpoint( - self, local_path: str, hdfs_path: str = None, global_step: int = 0, max_ckpt_to_keep: int = None + self, + local_path: str, + hdfs_path: str = None, + global_step: int = 0, + max_ckpt_to_keep: int = None, ): raise NotImplementedError @staticmethod def checkpath(local_path: str, hdfs_path: str): - assert local_path is not None or hdfs_path is not None, "local_path and hdfs_path cannot be both None" - return local_path is not None, local_path if local_path is not None else hdfs_path + assert ( + local_path is not None or hdfs_path is not None + ), "local_path and hdfs_path cannot be both None" + return local_path is not None, ( + local_path if local_path is not None else hdfs_path + ) def remove_previous_save_local_path(self, path): if isinstance(path, str): @@ -203,7 +217,9 @@ def get_checkpoint_tracker_filename(root_path: str): return os.path.join(root_path, "latest_checkpointed_iteration.txt") -def should_save_ckpt_esi(max_steps_duration: float, save_ckpt_duration: float = 60, redundant_time: float = 0) -> bool: +def 
should_save_ckpt_esi( + max_steps_duration: float, save_ckpt_duration: float = 60, redundant_time: float = 0 +) -> bool: """ Determine if checkpoint should be saved based on capacity esi expiration. @@ -213,7 +229,9 @@ def should_save_ckpt_esi(max_steps_duration: float, save_ckpt_duration: float = redundant_time: Additional buffer time (seconds) for unexpected delays (default: 0) """ exp_ts_mlp = os.getenv("MLP_CURRENT_CAPACITY_BLOCK_EXPIRATION_TIMESTAMP") # vemlp - exp_ts_aws = os.getenv("SAGEMAKER_CURRENT_CAPACITY_BLOCK_EXPIRATION_TIMESTAMP") # aws + exp_ts_aws = os.getenv( + "SAGEMAKER_CURRENT_CAPACITY_BLOCK_EXPIRATION_TIMESTAMP" + ) # aws if exp_ts_mlp: try: import time @@ -231,7 +249,9 @@ def should_save_ckpt_esi(max_steps_duration: float, save_ckpt_duration: float = expiration_time = datetime.fromtimestamp(int(exp_ts_aws)) time_difference = expiration_time - datetime.now() - threshold_minutes = (save_ckpt_duration + max_steps_duration + redundant_time) / 60 + threshold_minutes = ( + save_ckpt_duration + max_steps_duration + redundant_time + ) / 60 return time_difference < timedelta(minutes=threshold_minutes) else: return False diff --git a/Agent0/executor_train/verl/verl/utils/checkpoint/fsdp_checkpoint_manager.py b/Agent0/executor_train/verl/verl/utils/checkpoint/fsdp_checkpoint_manager.py index e042ae8..73c5ad7 100644 --- a/Agent0/executor_train/verl/verl/utils/checkpoint/fsdp_checkpoint_manager.py +++ b/Agent0/executor_train/verl/verl/utils/checkpoint/fsdp_checkpoint_manager.py @@ -24,12 +24,20 @@ from accelerate import init_empty_weights from omegaconf import DictConfig from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.distributed.fsdp import ShardedOptimStateDictConfig, ShardedStateDictConfig, StateDictType +from torch.distributed.fsdp import ( + ShardedOptimStateDictConfig, + ShardedStateDictConfig, + StateDictType, +) from transformers import GenerationConfig, PreTrainedTokenizer, ProcessorMixin from verl.utils.device 
import is_cuda_available from verl.utils.fs import copy_to_local, is_non_local, local_mkdir_safe -from verl.utils.fsdp_utils import fsdp_version, get_fsdp_full_state_dict, get_fsdp_state_ctx +from verl.utils.fsdp_utils import ( + fsdp_version, + get_fsdp_full_state_dict, + get_fsdp_state_ctx, +) from verl.utils.logger import log_with_rank from .checkpoint_manager import BaseCheckpointManager @@ -80,7 +88,9 @@ def __init__( if processing_class is None: assert "tokenizer" in kwargs, "tokenizer or processor must be provided" warnings.warn( - "`tokenizer` is deprecated. use `processing_class` instead.", DeprecationWarning, stacklevel=2 + "`tokenizer` is deprecated. use `processing_class` instead.", + DeprecationWarning, + stacklevel=2, ) processing_class = kwargs.pop("tokenizer") @@ -92,7 +102,9 @@ def __init__( checkpoint_config=checkpoint_config, ) - def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_after_load=False): + def load_checkpoint( + self, local_path: str, hdfs_path: str = None, del_local_after_load=False + ): """ Load an FSDP checkpoint for this rank. 
@@ -110,11 +122,13 @@ def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_afte # check if the checkpoint_load_contents is valid if self.should_load_model: - assert self.model is not None, "model must be provided when checkpoint_contents.load includes ['model']" + assert ( + self.model is not None + ), "model must be provided when checkpoint_contents.load includes ['model']" if self.should_load_optimizer: - assert self.optimizer is not None, ( - "optimizer must be provided when checkpoint_contents.load includes ['optimizer']" - ) + assert ( + self.optimizer is not None + ), "optimizer must be provided when checkpoint_contents.load includes ['optimizer']" # every rank download its own checkpoint state_dict_cfg = ( @@ -123,28 +137,47 @@ def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_afte else None ) optim_cfg = ( - ShardedOptimStateDictConfig(offload_to_cpu=True if is_cuda_available else False) + ShardedOptimStateDictConfig( + offload_to_cpu=True if is_cuda_available else False + ) if self.should_load_optimizer else None ) - with get_fsdp_state_ctx(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, optim_cfg): + with get_fsdp_state_ctx( + self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, optim_cfg + ): if self.should_load_model: - remote_model_path = os.path.join(local_path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt") + remote_model_path = os.path.join( + local_path, + f"model_world_size_{self.world_size}_rank_{self.rank}.pt", + ) local_model_path = copy_to_local(remote_model_path) model_state_dict = torch.load(local_model_path, weights_only=False) self.model.load_state_dict(model_state_dict) - log_with_rank(f"Loaded model from {remote_model_path}", rank=self.rank, logger=logger) + log_with_rank( + f"Loaded model from {remote_model_path}", + rank=self.rank, + logger=logger, + ) if self.should_load_optimizer: - remote_optim_path = os.path.join(local_path, 
f"optim_world_size_{self.world_size}_rank_{self.rank}.pt") + remote_optim_path = os.path.join( + local_path, + f"optim_world_size_{self.world_size}_rank_{self.rank}.pt", + ) local_optim_path = copy_to_local(remote_optim_path) optimizer_state_dict = torch.load(local_optim_path, weights_only=False) self.optimizer.load_state_dict(optimizer_state_dict) - log_with_rank(f"Loaded optimizer from {remote_optim_path}", rank=self.rank, logger=logger) + log_with_rank( + f"Loaded optimizer from {remote_optim_path}", + rank=self.rank, + logger=logger, + ) if self.should_load_extra: remote_extra_state_path = os.path.join( - local_path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt" + local_path, + f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt", ) local_extra_state_path = copy_to_local(remote_extra_state_path) extra_state_dict = torch.load(local_extra_state_path, weights_only=False) @@ -152,18 +185,30 @@ def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_afte if "rng" in extra_state_dict: # 'rng' may not exist for backward compatibility self.load_rng_state(extra_state_dict["rng"]) - log_with_rank(f"Loaded rng from {remote_extra_state_path}", rank=self.rank, logger=logger) + log_with_rank( + f"Loaded rng from {remote_extra_state_path}", + rank=self.rank, + logger=logger, + ) lr_scheduler_state_dict = extra_state_dict["lr_scheduler"] if lr_scheduler_state_dict is not None and self.lr_scheduler is not None: self.lr_scheduler.load_state_dict(lr_scheduler_state_dict) - log_with_rank(f"Loaded lr_scheduler from {remote_extra_state_path}", rank=self.rank, logger=logger) + log_with_rank( + f"Loaded lr_scheduler from {remote_extra_state_path}", + rank=self.rank, + logger=logger, + ) if self.rank == 0 and del_local_after_load: try: os.remove(local_model_path) if is_non_local(local_model_path) else None os.remove(local_optim_path) if is_non_local(local_optim_path) else None - os.remove(local_extra_state_path) if 
is_non_local(local_extra_state_path) else None + ( + os.remove(local_extra_state_path) + if is_non_local(local_extra_state_path) + else None + ) except Exception as e: log_with_rank( f"remove local resume ckpt file after loading failed, exception {e} will be ignored", @@ -174,7 +219,13 @@ def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_afte # wait for everyone to load checkpoints torch.distributed.barrier() - def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: int = 0, max_ckpt_to_keep=None): + def save_checkpoint( + self, + local_path: str, + hdfs_path: str = None, + global_step: int = 0, + max_ckpt_to_keep=None, + ): """ Save an FSDP checkpoint for this rank. @@ -215,40 +266,73 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i # check if the checkpoint_save_contents is valid if self.should_save_model: - assert self.model is not None, "model must be provided when checkpoint_contents.save includes ['model']" + assert ( + self.model is not None + ), "model must be provided when checkpoint_contents.save includes ['model']" if self.should_save_optimizer: - assert self.optimizer is not None, ( - "optimizer must be provided when checkpoint_contents.save includes ['optimizer']" - ) + assert ( + self.optimizer is not None + ), "optimizer must be provided when checkpoint_contents.save includes ['optimizer']" # every rank will save its own model and optim shard - state_dict_cfg = ShardedStateDictConfig(offload_to_cpu=True if is_cuda_available else False) - optim_cfg = ShardedOptimStateDictConfig(offload_to_cpu=True if is_cuda_available else False) + state_dict_cfg = ShardedStateDictConfig( + offload_to_cpu=True if is_cuda_available else False + ) + optim_cfg = ShardedOptimStateDictConfig( + offload_to_cpu=True if is_cuda_available else False + ) with warnings.catch_warnings(): warnings.simplefilter("ignore") - with get_fsdp_state_ctx(self.model, StateDictType.SHARDED_STATE_DICT, 
state_dict_cfg, optim_cfg): - model_path = os.path.join(local_path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt") - optim_path = os.path.join(local_path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt") - extra_path = os.path.join(local_path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt") + with get_fsdp_state_ctx( + self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, optim_cfg + ): + model_path = os.path.join( + local_path, + f"model_world_size_{self.world_size}_rank_{self.rank}.pt", + ) + optim_path = os.path.join( + local_path, + f"optim_world_size_{self.world_size}_rank_{self.rank}.pt", + ) + extra_path = os.path.join( + local_path, + f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt", + ) if self.should_save_model: model_state_dict = self.model.state_dict() torch.save(model_state_dict, model_path) - log_with_rank(f"Saved model to {os.path.abspath(model_path)}", rank=self.rank, logger=logger) + log_with_rank( + f"Saved model to {os.path.abspath(model_path)}", + rank=self.rank, + logger=logger, + ) if self.should_save_optimizer: optimizer_state_dict = self.optimizer.state_dict() torch.save(optimizer_state_dict, optim_path) - log_with_rank(f"Saved optim to {os.path.abspath(optim_path)}", rank=self.rank, logger=logger) + log_with_rank( + f"Saved optim to {os.path.abspath(optim_path)}", + rank=self.rank, + logger=logger, + ) if self.should_save_extra: - lr_scheduler_state_dict = self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None + lr_scheduler_state_dict = ( + self.lr_scheduler.state_dict() + if self.lr_scheduler is not None + else None + ) extra_state_dict = { "lr_scheduler": lr_scheduler_state_dict, "rng": self.get_rng_state(), } torch.save(extra_state_dict, extra_path) - log_with_rank(f"Saved extra_state to {os.path.abspath(extra_path)}", rank=self.rank, logger=logger) + log_with_rank( + f"Saved extra_state to {os.path.abspath(extra_path)}", + rank=self.rank, + logger=logger, 
+ ) if self.rank == 0: # Save HF tokenizer/processor and model config on rank 0 to huggingface/ directory, no matter whether @@ -262,10 +346,16 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i hf_config_tokenizer_path = os.path.join(local_path, "huggingface") local_mkdir_safe(hf_config_tokenizer_path) model_config = unwrap_model.config - if unwrap_model.can_generate() and hasattr(model_config, "name_or_path") and model_config.name_or_path: + if ( + unwrap_model.can_generate() + and hasattr(model_config, "name_or_path") + and model_config.name_or_path + ): # Some model's name_or_path is empty if not initialized from pretrained, # in this cases, we don't save generation config. - generation_config = GenerationConfig.from_pretrained(model_config.name_or_path) + generation_config = GenerationConfig.from_pretrained( + model_config.name_or_path + ) generation_config.save_pretrained(hf_config_tokenizer_path) else: generation_config = None @@ -294,7 +384,9 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i if self.should_save_hf_model: # Only rank 0 will save hf model and, # offload to cpu to save LLMs which may be too large to fit in one GPU - state_dict = get_fsdp_full_state_dict(self.model, offload_to_cpu=True, rank0_only=True) + state_dict = get_fsdp_full_state_dict( + self.model, offload_to_cpu=True, rank0_only=True + ) if self.rank == 0: hf_local_path = os.path.join(local_path, "huggingface") @@ -313,10 +405,14 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i auto_model_cls = AutoModelForVision2Seq else: - raise NotImplementedError(f"Unknown architecture {model_config['architectures']}") + raise NotImplementedError( + f"Unknown architecture {model_config['architectures']}" + ) with init_empty_weights(): - save_model = auto_model_cls.from_config(model_config, torch_dtype=torch.bfloat16) + save_model = auto_model_cls.from_config( + model_config, 
torch_dtype=torch.bfloat16 + ) save_model.to_empty(device="cpu") if save_model.can_generate(): diff --git a/Agent0/executor_train/verl/verl/utils/checkpoint/megatron_checkpoint_manager.py b/Agent0/executor_train/verl/verl/utils/checkpoint/megatron_checkpoint_manager.py index f0071b8..6135386 100644 --- a/Agent0/executor_train/verl/verl/utils/checkpoint/megatron_checkpoint_manager.py +++ b/Agent0/executor_train/verl/verl/utils/checkpoint/megatron_checkpoint_manager.py @@ -31,7 +31,10 @@ from verl.utils.device import get_device_name, get_torch_device from verl.utils.fs import is_non_local, local_mkdir_safe from verl.utils.logger import log_with_rank -from verl.utils.megatron.dist_checkpointing import load_dist_checkpointing, save_dist_checkpointing +from verl.utils.megatron.dist_checkpointing import ( + load_dist_checkpointing, + save_dist_checkpointing, +) from verl.utils.megatron_utils import ( get_dist_checkpoint_path, get_hf_model_checkpoint_path, @@ -143,12 +146,16 @@ def __init__( self.use_checkpoint_opt_param_scheduler = use_checkpoint_opt_param_scheduler self.bridge = bridge self.rank = torch.distributed.get_rank() - self.use_dist_checkpointing = use_dist_checkpointing or not self.bridge or self.is_value_model + self.use_dist_checkpointing = ( + use_dist_checkpointing or not self.bridge or self.is_value_model + ) self.use_hf_checkpoint = not self.use_dist_checkpointing self.weight_saver = get_weight_saver(self.arch) - def get_rng_state(self, use_dist_ckpt: bool = True, data_parallel_random_init: bool = False): + def get_rng_state( + self, use_dist_ckpt: bool = True, data_parallel_random_init: bool = False + ): """collect rng state across data parallel ranks""" rng_state = { "random_rng_state": random.getstate(), @@ -158,12 +165,20 @@ def get_rng_state(self, use_dist_ckpt: bool = True, data_parallel_random_init: b } if get_device_name() != "cpu": - rng_state[f"{get_device_name()}_rng_state"] = get_torch_device().get_rng_state() + 
rng_state[f"{get_device_name()}_rng_state"] = ( + get_torch_device().get_rng_state() + ) rng_state_list = None - if torch.distributed.is_initialized() and mpu.get_data_parallel_world_size() > 1 and data_parallel_random_init: + if ( + torch.distributed.is_initialized() + and mpu.get_data_parallel_world_size() > 1 + and data_parallel_random_init + ): rng_state_list = [None for i in range(mpu.get_data_parallel_world_size())] - torch.distributed.all_gather_object(rng_state_list, rng_state, group=mpu.get_data_parallel_group()) + torch.distributed.all_gather_object( + rng_state_list, rng_state, group=mpu.get_data_parallel_group() + ) else: rng_state_list = [rng_state] @@ -217,7 +232,9 @@ def get_checkpoint_name( if not pipeline_parallel: common_path = os.path.join(checkpoints_path, f"mp_rank_{tensor_rank:02d}") else: - common_path = os.path.join(checkpoints_path, f"mp_rank_{tensor_rank:02d}_{pipeline_rank:03d}") + common_path = os.path.join( + checkpoints_path, f"mp_rank_{tensor_rank:02d}_{pipeline_rank:03d}" + ) if expert_parallel: common_path = common_path + f"_{expert_rank:03d}" @@ -263,7 +280,9 @@ def generate_state_dict(self): return state_dict - def load_rng_states(self, rng_states, data_parallel_random_init=False, use_dist_ckpt=True): + def load_rng_states( + self, rng_states, data_parallel_random_init=False, use_dist_ckpt=True + ): # access rng_state for data parallel rank if data_parallel_random_init: rng_states = rng_states[mpu.get_data_parallel_rank()] @@ -274,26 +293,42 @@ def load_rng_states(self, rng_states, data_parallel_random_init=False, use_dist_ torch.set_rng_state(rng_states["torch_rng_state"]) if get_device_name() != "cpu": - get_torch_device().set_rng_state(rng_states[f"{get_device_name()}_rng_state"]) + get_torch_device().set_rng_state( + rng_states[f"{get_device_name()}_rng_state"] + ) # Check for empty states array if not rng_states["rng_tracker_states"]: raise KeyError - 
tensor_parallel.get_cuda_rng_tracker().set_states(rng_states["rng_tracker_states"]) + tensor_parallel.get_cuda_rng_tracker().set_states( + rng_states["rng_tracker_states"] + ) - def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_after_load=False): + def load_checkpoint( + self, local_path: str, hdfs_path: str = None, del_local_after_load=False + ): if local_path is not None: - assert os.path.exists(local_path), f"Checkpoint path {local_path} does not exist." + assert os.path.exists( + local_path + ), f"Checkpoint path {local_path} does not exist." dist_checkpoint_path = get_dist_checkpoint_path(local_path) # Get State Dict for loading sharded_state_dict = self.generate_state_dict() - log_with_rank(f"Generated state dict for saving: {sharded_state_dict.keys()}", rank=self.rank, logger=logger) + log_with_rank( + f"Generated state dict for saving: {sharded_state_dict.keys()}", + rank=self.rank, + logger=logger, + ) for vpp_rank, model in enumerate(self.model): if len(self.model) > 1: model_i_keys = sharded_state_dict[f"model{vpp_rank}"].keys() - log_with_rank(f"Generated state dict for saving: {model_i_keys}", rank=self.rank, logger=logger) + log_with_rank( + f"Generated state dict for saving: {model_i_keys}", + rank=self.rank, + logger=logger, + ) else: log_with_rank( f"Generated state dict for saving: {sharded_state_dict['model'].keys()}", @@ -315,23 +350,37 @@ def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_afte if len(self.model) == 1: model_state_dict = state_dict["model"] else: - assert f"model{vpp_rank}" in state_dict, f"model{vpp_rank} not found in state_dict" + assert ( + f"model{vpp_rank}" in state_dict + ), f"model{vpp_rank} not found in state_dict" model_state_dict = state_dict[f"model{vpp_rank}"] mpu.set_virtual_pipeline_model_parallel_rank(vpp_rank) self.model[vpp_rank].load_state_dict(model_state_dict) - log_with_rank(f"Loaded sharded model checkpoint from {local_path}", rank=self.rank, logger=logger) 
+ log_with_rank( + f"Loaded sharded model checkpoint from {local_path}", + rank=self.rank, + logger=logger, + ) elif self.should_load_model and self.use_hf_checkpoint: hf_model_path = get_hf_model_checkpoint_path(local_path) self.bridge.load_weights(self.model, hf_model_path) - log_with_rank(f"Loaded HF model checkpoint from {hf_model_path} with bridge", rank=self.rank, logger=logger) + log_with_rank( + f"Loaded HF model checkpoint from {hf_model_path} with bridge", + rank=self.rank, + logger=logger, + ) if self.should_load_optimizer: - assert "optimizer" in state_dict, ( - f"Optimizer state dict not found in {state_dict.keys()}. Please check the checkpoint file {local_path}." - ) + assert ( + "optimizer" in state_dict + ), f"Optimizer state dict not found in {state_dict.keys()}. Please check the checkpoint file {local_path}." optimizer_state_dict = state_dict["optimizer"] self.optimizer.load_state_dict(optimizer_state_dict) - log_with_rank(f"Loaded optimizer checkpoint from {local_path}", rank=self.rank, logger=logger) + log_with_rank( + f"Loaded optimizer checkpoint from {local_path}", + rank=self.rank, + logger=logger, + ) if self.use_checkpoint_opt_param_scheduler: assert "lr_scheduler" in state_dict, ( f"LR scheduler state dict not found in {state_dict.keys()}. Please check the checkpoint file " @@ -340,15 +389,21 @@ def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_afte lr_scheduler_state_dict = state_dict["lr_scheduler"] if self.lr_scheduler is not None: self.lr_scheduler.load_state_dict(lr_scheduler_state_dict) - log_with_rank(f"Loaded LR scheduler checkpoint from {local_path}", rank=self.rank, logger=logger) + log_with_rank( + f"Loaded LR scheduler checkpoint from {local_path}", + rank=self.rank, + logger=logger, + ) if self.should_load_extra: - assert "rng_state" in state_dict, ( - f"RNG state dict not found in {state_dict.keys()}. Please check the checkpoint file {local_path}." 
- ) + assert ( + "rng_state" in state_dict + ), f"RNG state dict not found in {state_dict.keys()}. Please check the checkpoint file {local_path}." rng_state = state_dict["rng_state"] self.load_rng_states(rng_state) - log_with_rank(f"Loaded RNG states from {local_path}", rank=self.rank, logger=logger) + log_with_rank( + f"Loaded RNG states from {local_path}", rank=self.rank, logger=logger + ) if del_local_after_load: try: @@ -360,7 +415,13 @@ def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_afte logger=logger, ) - def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: int = 0, max_ckpt_to_keep=None): + def save_checkpoint( + self, + local_path: str, + hdfs_path: str = None, + global_step: int = 0, + max_ckpt_to_keep=None, + ): # record the previous global step self.previous_global_step = global_step @@ -381,14 +442,24 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i if self.use_dist_checkpointing: # Generate state dict for saving state_dict = self.generate_state_dict() - log_with_rank(f"Generated state dict for saving: {state_dict.keys()}", rank=self.rank, logger=logger) + log_with_rank( + f"Generated state dict for saving: {state_dict.keys()}", + rank=self.rank, + logger=logger, + ) for vpp_rank, model in enumerate(self.model): if len(self.model) > 1: model_i_keys = state_dict[f"model{vpp_rank}"].keys() - log_with_rank(f"Generated state dict for saving: {model_i_keys}", rank=self.rank, logger=logger) + log_with_rank( + f"Generated state dict for saving: {model_i_keys}", + rank=self.rank, + logger=logger, + ) else: log_with_rank( - f"Generated state dict for saving: {state_dict['model'].keys()}", rank=self.rank, logger=logger + f"Generated state dict for saving: {state_dict['model'].keys()}", + rank=self.rank, + logger=logger, ) # Start Async save if enabled async_save_request = save_dist_checkpointing( @@ -399,14 +470,26 @@ def save_checkpoint(self, local_path: str, hdfs_path: str 
= None, global_step: i # Synchronize all async save requests if not self.checkpoint_config.async_save: - assert async_save_request is None, "Async save request should be None when not using async save." + assert ( + async_save_request is None + ), "Async save request should be None when not using async save." torch.distributed.barrier() else: - assert self.use_hf_checkpoint, "use_hf_checkpoint should be True when not using dist checkpointing" - log_with_rank(f"Saving HF model checkpoint to {local_path} with bridge", rank=self.rank, logger=logger) + assert ( + self.use_hf_checkpoint + ), "use_hf_checkpoint should be True when not using dist checkpointing" + log_with_rank( + f"Saving HF model checkpoint to {local_path} with bridge", + rank=self.rank, + logger=logger, + ) hf_ckpt_path = get_hf_model_checkpoint_path(local_path) self.bridge.save_weights(self.model, hf_ckpt_path) - log_with_rank(f"Saved bridge checkpoint to {hf_ckpt_path}", rank=self.rank, logger=logger) + log_with_rank( + f"Saved bridge checkpoint to {hf_ckpt_path}", + rank=self.rank, + logger=logger, + ) if self.should_save_model: # Only rank 0 saves the hf config and tokenizer to huggingface path @@ -417,9 +500,14 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i self.processing_class.save_pretrained(hf_config_tokenizer_path) # Save huggingface config self.hf_config.save_pretrained(hf_config_tokenizer_path) - if hasattr(self.hf_config, "name_or_path") and self.hf_config.name_or_path: + if ( + hasattr(self.hf_config, "name_or_path") + and self.hf_config.name_or_path + ): try: - generation_config = GenerationConfig.from_pretrained(self.hf_config.name_or_path) + generation_config = GenerationConfig.from_pretrained( + self.hf_config.name_or_path + ) generation_config.save_pretrained(hf_config_tokenizer_path) except Exception: # if the generation config isn't available, we don't save it @@ -441,14 +529,18 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, 
global_step: i pop_keys = [] for key, value in transformer_config_dict.items(): if type(value) in to_convert_types: - transformer_config_dict[key] = to_convert_types[type(value)](value) + transformer_config_dict[key] = to_convert_types[type(value)]( + value + ) if type(value) in ignore_types: pop_keys.append(key) if callable(value): pop_keys.append(key) for key in pop_keys: transformer_config_dict.pop(key) - transformer_config_path = get_transformer_config_checkpoint_path(local_path) + transformer_config_path = get_transformer_config_checkpoint_path( + local_path + ) with open(transformer_config_path, "w") as f: json.dump(transformer_config_dict, f, indent=2) @@ -481,7 +573,9 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i else: from transformers import AutoModelForCausalLM - model = AutoModelForCausalLM.from_pretrained(self.config.model.path, torch_dtype="auto") + model = AutoModelForCausalLM.from_pretrained( + self.config.model.path, torch_dtype="auto" + ) model.save_pretrained(hf_model_ckpt_path, state_dict=state_dict) log_with_rank( f"Saved Huggingface config and tokenizer to {hf_model_ckpt_path}", @@ -492,32 +586,52 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i if hdfs_path is not None: log_with_rank( - f"Uploading checkpoint to {hdfs_path}", rank=self.rank, logger=logger, log_only_rank_0=True + f"Uploading checkpoint to {hdfs_path}", + rank=self.rank, + logger=logger, + log_only_rank_0=True, ) from verl.utils import hdfs_io hdfs_io.makedirs(hdfs_path, exist_ok=True) - hdfs_io.copy(src=hf_model_ckpt_path, dst=hdfs_path, dirs_exist_ok=True) + hdfs_io.copy( + src=hf_model_ckpt_path, dst=hdfs_path, dirs_exist_ok=True + ) log_with_rank( - f"HDFS checkpoint uploaded to {hdfs_path}", rank=self.rank, logger=logger, log_only_rank_0=True + f"HDFS checkpoint uploaded to {hdfs_path}", + rank=self.rank, + logger=logger, + log_only_rank_0=True, ) def finalize_save_fn(): # Rank 0 uploads checkpoint to 
HDFS if hdfs_path is provided log_with_rank( - f"Dist checkpointing save completed for {dist_checkpoint_path}", rank=self.rank, logger=logger + f"Dist checkpointing save completed for {dist_checkpoint_path}", + rank=self.rank, + logger=logger, ) if self.rank == 0: if hdfs_path is not None: - log_with_rank(f"Uploading checkpoint to {hdfs_path}", rank=self.rank, logger=logger) + log_with_rank( + f"Uploading checkpoint to {hdfs_path}", + rank=self.rank, + logger=logger, + ) from verl.utils import hdfs_io hdfs_io.makedirs(hdfs_path, exist_ok=True) - hdfs_io.copy(src=dist_checkpoint_path, dst=hdfs_path, dirs_exist_ok=True) - hdfs_io.copy(src=hf_config_tokenizer_path, dst=hdfs_path, dirs_exist_ok=True) + hdfs_io.copy( + src=dist_checkpoint_path, dst=hdfs_path, dirs_exist_ok=True + ) + hdfs_io.copy( + src=hf_config_tokenizer_path, dst=hdfs_path, dirs_exist_ok=True + ) if self.checkpoint_config.async_save: - assert async_save_request is not None, "Async save request should not be None when using async save." + assert ( + async_save_request is not None + ), "Async save request should not be None when using async save." async_save_request.add_finalize_fn(finalize_save_fn) else: finalize_save_fn() diff --git a/Agent0/executor_train/verl/verl/utils/config.py b/Agent0/executor_train/verl/verl/utils/config.py index f1c301f..d481f6a 100644 --- a/Agent0/executor_train/verl/verl/utils/config.py +++ b/Agent0/executor_train/verl/verl/utils/config.py @@ -20,7 +20,9 @@ __all__ = ["omega_conf_to_dataclass"] -def omega_conf_to_dataclass(config: DictConfig | dict, dataclass_type: Optional[type[Any]] = None) -> Any: +def omega_conf_to_dataclass( + config: DictConfig | dict, dataclass_type: Optional[type[Any]] = None +) -> Any: """ Convert an OmegaConf DictConfig to a dataclass. 
diff --git a/Agent0/executor_train/verl/verl/utils/dataset/multiturn_sft_dataset.py b/Agent0/executor_train/verl/verl/utils/dataset/multiturn_sft_dataset.py index e3eed0f..0508132 100644 --- a/Agent0/executor_train/verl/verl/utils/dataset/multiturn_sft_dataset.py +++ b/Agent0/executor_train/verl/verl/utils/dataset/multiturn_sft_dataset.py @@ -32,7 +32,9 @@ def convert_nested_value_to_list_recursive(data_item): if isinstance(data_item, dict): - return {k: convert_nested_value_to_list_recursive(v) for k, v in data_item.items()} + return { + k: convert_nested_value_to_list_recursive(v) for k, v in data_item.items() + } elif isinstance(data_item, list): return [convert_nested_value_to_list_recursive(elem) for elem in data_item] elif isinstance(data_item, np.ndarray): @@ -57,7 +59,9 @@ def __init__(self, parquet_files: str | list[str], tokenizer, config=None): multiturn_config = config.get("multiturn", {}) self.messages_key = multiturn_config.get("messages_key", "messages") self.tools_key = multiturn_config.get("tools_key", "tools") - self.enable_thinking_key = multiturn_config.get("enable_thinking_key", "enable_thinking") + self.enable_thinking_key = multiturn_config.get( + "enable_thinking_key", "enable_thinking" + ) assert self.truncation in ["error", "left", "right"] if not isinstance(parquet_files, list): @@ -73,14 +77,19 @@ def __init__(self, parquet_files: str | list[str], tokenizer, config=None): def _download(self): for i, parquet_file in enumerate(self.parquet_files): - self.parquet_files[i] = copy_local_path_from_hdfs(parquet_file, verbose=True) + self.parquet_files[i] = copy_local_path_from_hdfs( + parquet_file, verbose=True + ) def _read_files_and_process(self): def series_to_item(ls): import numpy import pandas - while isinstance(ls, pandas.core.series.Series | numpy.ndarray) and len(ls) == 1: + while ( + isinstance(ls, pandas.core.series.Series | numpy.ndarray) + and len(ls) == 1 + ): ls = ls[0] return ls @@ -95,7 +104,11 @@ def series_to_item(ls): # 
Extract tools list from dataframe if self.tools_key in self.dataframe.columns: - self.tools = self.dataframe[self.tools_key].apply(convert_nested_value_to_list_recursive).tolist() + self.tools = ( + self.dataframe[self.tools_key] + .apply(convert_nested_value_to_list_recursive) + .tolist() + ) else: self.tools = None # Extract enable_thinking list from dataframe @@ -138,12 +151,14 @@ def _process_message_tokens( tools=tools, ) if is_assistant: - prev_applied_text_w_generation_prompt = self.tokenizer.apply_chat_template( - messages[:start_idx], - tokenize=False, - add_generation_prompt=True, - enable_thinking=enable_thinking, - tools=tools, + prev_applied_text_w_generation_prompt = ( + self.tokenizer.apply_chat_template( + messages[:start_idx], + tokenize=False, + add_generation_prompt=True, + enable_thinking=enable_thinking, + tools=tools, + ) ) else: @@ -158,7 +173,9 @@ def _process_message_tokens( ) # Get tokens for the current message only if is_assistant: - generation_prompt_text = prev_applied_text_w_generation_prompt[len(prev_applied_text) :] + generation_prompt_text = prev_applied_text_w_generation_prompt[ + len(prev_applied_text) : + ] generation_prompt_tokens = self.tokenizer.encode( generation_prompt_text, add_special_tokens=False, @@ -228,7 +245,9 @@ def __getitem__(self, item): tokenizer = self.tokenizer messages = self.messages[item] tools = self.tools[item] if self.tools is not None else None - enable_thinking = self.enable_thinking[item] if self.enable_thinking is not None else None + enable_thinking = ( + self.enable_thinking[item] if self.enable_thinking is not None else None + ) if self.tools is not None: tools = json.loads(self.tools[item]) @@ -263,7 +282,12 @@ def __getitem__(self, item): if cur_messages["role"] == "assistant": # Process assistant message tokens, loss_mask, attention_mask = self._process_message_tokens( - messages, i, i + 1, is_assistant=True, enable_thinking=enable_thinking, tools=tools + messages, + i, + i + 1, + 
is_assistant=True, + enable_thinking=enable_thinking, + tools=tools, ) concat_tokens.extend(tokens) concat_loss_mask.extend(loss_mask) @@ -305,10 +329,22 @@ def __getitem__(self, item): sequence_length = input_ids.shape[0] if sequence_length < self.max_length: # Pad sequences - pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else 0 - padded_input_ids = torch.full((self.max_length - sequence_length,), pad_token_id, dtype=input_ids.dtype) - padded_attention_mask = torch.zeros((self.max_length - sequence_length,), dtype=attention_mask.dtype) - padded_loss_mask = torch.zeros((self.max_length - sequence_length,), dtype=loss_mask.dtype) + pad_token_id = ( + self.tokenizer.pad_token_id + if self.tokenizer.pad_token_id is not None + else 0 + ) + padded_input_ids = torch.full( + (self.max_length - sequence_length,), + pad_token_id, + dtype=input_ids.dtype, + ) + padded_attention_mask = torch.zeros( + (self.max_length - sequence_length,), dtype=attention_mask.dtype + ) + padded_loss_mask = torch.zeros( + (self.max_length - sequence_length,), dtype=loss_mask.dtype + ) input_ids = torch.cat((input_ids, padded_input_ids)) attention_mask = torch.cat((attention_mask, padded_attention_mask)) @@ -323,7 +359,9 @@ def __getitem__(self, item): attention_mask = attention_mask[: self.max_length] loss_mask = loss_mask[: self.max_length] elif self.truncation == "error": - raise ValueError(f"{sequence_length=} is larger than {self.max_length=}") + raise ValueError( + f"{sequence_length=} is larger than {self.max_length=}" + ) else: raise ValueError(f"Unknown truncation method {self.truncation}") diff --git a/Agent0/executor_train/verl/verl/utils/dataset/rl_dataset.py b/Agent0/executor_train/verl/verl/utils/dataset/rl_dataset.py index e053a67..28ba050 100644 --- a/Agent0/executor_train/verl/verl/utils/dataset/rl_dataset.py +++ b/Agent0/executor_train/verl/verl/utils/dataset/rl_dataset.py @@ -98,7 +98,9 @@ def __init__( self.processor = processor 
self.config = config - self.cache_dir = os.path.expanduser(config.get("cache_dir", "~/.cache/verl/rlhf")) + self.cache_dir = os.path.expanduser( + config.get("cache_dir", "~/.cache/verl/rlhf") + ) self.prompt_key = config.get("prompt_key", "prompt") self.image_key = config.get("image_key", "images") self.video_key = config.get("video_key", "videos") @@ -108,7 +110,9 @@ def __init__( self.truncation = config.get("truncation", "error") self.filter_overlong_prompts = config.get("filter_overlong_prompts", True) - self.num_workers = config.get("filter_overlong_prompts_workers", max(1, os.cpu_count() // 4)) + self.num_workers = config.get( + "filter_overlong_prompts_workers", max(1, os.cpu_count() // 4) + ) self.num_workers = min(self.num_workers, os.cpu_count()) self.use_shm = config.get("use_shm", False) self.chat_template_func = config.get("chat_template_func", None) @@ -123,15 +127,21 @@ def __init__( def _download(self, use_origin_parquet=False): from verl.utils.fs import copy_to_local - data_files = self.data_files if not use_origin_parquet else self.original_data_files + data_files = ( + self.data_files if not use_origin_parquet else self.original_data_files + ) for i, parquet_file in enumerate(data_files): - self.data_files[i] = copy_to_local(src=parquet_file, cache_dir=self.cache_dir, use_shm=self.use_shm) + self.data_files[i] = copy_to_local( + src=parquet_file, cache_dir=self.cache_dir, use_shm=self.use_shm + ) def _read_files_and_tokenize(self): dataframes = [] for parquet_file in self.data_files: # read parquet files and cache - dataframe = datasets.load_dataset("parquet", data_files=parquet_file)["train"] + dataframe = datasets.load_dataset("parquet", data_files=parquet_file)[ + "train" + ] dataframes.append(dataframe) self.dataframe: datasets.Dataset = datasets.concatenate_datasets(dataframes) @@ -157,18 +167,30 @@ def doc2len(doc) -> int: messages, add_generation_prompt=True, tokenize=False ) images = ( - [process_image(image) for image in 
messages.pop(image_key)] if image_key in messages else None + [process_image(image) for image in messages.pop(image_key)] + if image_key in messages + else None ) videos = ( - [process_video(video) for video in messages.pop(video_key)] if video_key in messages else None + [process_video(video) for video in messages.pop(video_key)] + if video_key in messages + else None ) - return len(processor(text=[raw_prompt], images=images, videos=videos)["input_ids"][0]) + return len( + processor(text=[raw_prompt], images=images, videos=videos)[ + "input_ids" + ][0] + ) else: def doc2len(doc) -> int: - return len(tokenizer.apply_chat_template(doc[prompt_key], add_generation_prompt=True)) + return len( + tokenizer.apply_chat_template( + doc[prompt_key], add_generation_prompt=True + ) + ) dataframe = dataframe.filter( lambda doc: doc2len(doc) <= self.max_prompt_length, @@ -183,10 +205,14 @@ def resume_dataset_state(self): self.serialize_dataset = not hasattr(self, "original_data_files") # resume dataframe if not it's serialized in data.pt if not self.serialize_dataset: - self._download(use_origin_parquet=True) # download and resume from original parquet files + self._download( + use_origin_parquet=True + ) # download and resume from original parquet files self._read_files_and_tokenize() else: - print(r"old dataloader ckpt file is used, please train from scratch for better ckpt performance") + print( + r"old dataloader ckpt file is used, please train from scratch for better ckpt performance" + ) def __len__(self): return len(self.dataframe) @@ -223,26 +249,40 @@ def __getitem__(self, item): if self.processor is not None: from verl.utils.dataset.vision_utils import process_image, process_video - raw_prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False) + raw_prompt = self.processor.apply_chat_template( + messages, add_generation_prompt=True, tokenize=False + ) multi_modal_data = {} images = None - if self.image_key in row_dict and 
row_dict.get(self.image_key, None) is not None: - images = [process_image(image) for image in row_dict.pop(self.image_key)] + if ( + self.image_key in row_dict + and row_dict.get(self.image_key, None) is not None + ): + images = [ + process_image(image) for image in row_dict.pop(self.image_key) + ] # due to the image key is "image" instead of "images" in vllm, we need to use "image" here # link: https://github.com/vllm-project/vllm/blob/3c545c0c3b98ee642373a308197d750d0e449403/vllm/multimodal/parse.py#L205 multi_modal_data["image"] = images videos = None - if self.video_key in row_dict and row_dict.get(self.video_key, None) is not None: - videos = [process_video(video) for video in row_dict.pop(self.video_key)] + if ( + self.video_key in row_dict + and row_dict.get(self.video_key, None) is not None + ): + videos = [ + process_video(video) for video in row_dict.pop(self.video_key) + ] # due to the video key is "video" instead of "videos" in vllm, we need to use "video" here # link: https://github.com/vllm-project/vllm/blob/3c545c0c3b98ee642373a308197d750d0e449403/vllm/multimodal/parse.py#L205 multi_modal_data["video"] = [video.numpy() for video in videos] - model_inputs = self.processor(text=[raw_prompt], images=images, videos=videos, return_tensors="pt") + model_inputs = self.processor( + text=[raw_prompt], images=images, videos=videos, return_tensors="pt" + ) input_ids = model_inputs.pop("input_ids") attention_mask = model_inputs.pop("attention_mask") @@ -262,8 +302,12 @@ def __getitem__(self, item): row_dict["multi_modal_inputs"].pop("second_per_grid_ts", None) else: - raw_prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False) - model_inputs = self.tokenizer(raw_prompt, return_tensors="pt", add_special_tokens=False) + raw_prompt = self.tokenizer.apply_chat_template( + messages, add_generation_prompt=True, tokenize=False + ) + model_inputs = self.tokenizer( + raw_prompt, return_tensors="pt", add_special_tokens=False + ) 
input_ids = model_inputs.pop("input_ids") attention_mask = model_inputs.pop("attention_mask") @@ -276,7 +320,11 @@ def __getitem__(self, item): truncation=self.truncation, ) - if self.processor is not None and "Qwen2VLImageProcessor" in self.processor.image_processor.__class__.__name__: + if ( + self.processor is not None + and "Qwen2VLImageProcessor" + in self.processor.image_processor.__class__.__name__ + ): from verl.models.transformers.qwen2_vl import get_rope_index position_ids = [ @@ -306,9 +354,13 @@ def __getitem__(self, item): elif self.truncation == "middle": left_half = self.max_prompt_length // 2 right_half = self.max_prompt_length - left_half - raw_prompt_ids = raw_prompt_ids[:left_half] + raw_prompt_ids[-right_half:] + raw_prompt_ids = ( + raw_prompt_ids[:left_half] + raw_prompt_ids[-right_half:] + ) elif self.truncation == "error": - raise RuntimeError(f"Prompt length {len(raw_prompt_ids)} is longer than {self.max_prompt_length}.") + raise RuntimeError( + f"Prompt length {len(raw_prompt_ids)} is longer than {self.max_prompt_length}." 
+ ) row_dict["raw_prompt_ids"] = raw_prompt_ids # encode prompts without chat template @@ -322,10 +374,18 @@ def __getitem__(self, item): # add index for each prompt index = row_dict.get("extra_info", {}).get("index", 0) tools_kwargs = row_dict.get("extra_info", {}).get("tools_kwargs", {}) - interaction_kwargs = row_dict.get("extra_info", {}).get("interaction_kwargs", {}) - need_tools_kwargs = row_dict.get("extra_info", {}).get("need_tools_kwargs", self.need_tools_kwargs) + interaction_kwargs = row_dict.get("extra_info", {}).get( + "interaction_kwargs", {} + ) + need_tools_kwargs = row_dict.get("extra_info", {}).get( + "need_tools_kwargs", self.need_tools_kwargs + ) if need_tools_kwargs and not tools_kwargs: - logger.warning("tools_kwargs is empty for index {}, data source: {}", index, row_dict["data_source"]) + logger.warning( + "tools_kwargs is empty for index {}, data source: {}", + index, + row_dict["data_source"], + ) row_dict["index"] = index row_dict["tools_kwargs"] = tools_kwargs row_dict["interaction_kwargs"] = interaction_kwargs diff --git a/Agent0/executor_train/verl/verl/utils/dataset/rm_dataset.py b/Agent0/executor_train/verl/verl/utils/dataset/rm_dataset.py index 7af7923..d48377c 100644 --- a/Agent0/executor_train/verl/verl/utils/dataset/rm_dataset.py +++ b/Agent0/executor_train/verl/verl/utils/dataset/rm_dataset.py @@ -100,10 +100,23 @@ def _pad_to_length(self, input_ids, attention_mask): if curr_length < self.max_length: input_ids = torch.cat( - (input_ids, torch.zeros(size=(self.max_length - curr_length,), dtype=input_ids.dtype)), dim=-1 + ( + input_ids, + torch.zeros( + size=(self.max_length - curr_length,), dtype=input_ids.dtype + ), + ), + dim=-1, ) attention_mask = torch.cat( - (attention_mask, torch.zeros(size=(self.max_length - curr_length,), dtype=attention_mask.dtype)), dim=-1 + ( + attention_mask, + torch.zeros( + size=(self.max_length - curr_length,), + dtype=attention_mask.dtype, + ), + ), + dim=-1, ) elif curr_length > self.max_length: 
input_ids = input_ids[: self.max_length] @@ -117,13 +130,21 @@ def __getitem__(self, item): rejected_response = self.rejected_responses[item] prompt_ids = self.tokenizer(prompt, return_tensors="pt")["input_ids"][0] - chosen_response_ids = self.tokenizer(chosen_response, return_tensors="pt")["input_ids"][0] - rejected_response_ids = self.tokenizer(rejected_response, return_tensors="pt")["input_ids"][0] + chosen_response_ids = self.tokenizer(chosen_response, return_tensors="pt")[ + "input_ids" + ][0] + rejected_response_ids = self.tokenizer(rejected_response, return_tensors="pt")[ + "input_ids" + ][0] if self.add_eos: - chosen_response_ids = torch.cat((chosen_response_ids, torch.tensor([self.tokenizer.eos_token_id])), dim=-1) + chosen_response_ids = torch.cat( + (chosen_response_ids, torch.tensor([self.tokenizer.eos_token_id])), + dim=-1, + ) rejected_response_ids = torch.cat( - (rejected_response_ids, torch.tensor([self.tokenizer.eos_token_id])), dim=-1 + (rejected_response_ids, torch.tensor([self.tokenizer.eos_token_id])), + dim=-1, ) chosen_input_ids = torch.cat((prompt_ids, chosen_response_ids), dim=-1) @@ -132,11 +153,17 @@ def __getitem__(self, item): rejected_input_ids = torch.cat((prompt_ids, rejected_response_ids), dim=-1) rejected_attention_mask = torch.ones_like(rejected_input_ids) - chosen_input_ids, chosen_attention_mask = self._pad_to_length(chosen_input_ids, chosen_attention_mask) - rejected_input_ids, rejected_attention_mask = self._pad_to_length(rejected_input_ids, rejected_attention_mask) + chosen_input_ids, chosen_attention_mask = self._pad_to_length( + chosen_input_ids, chosen_attention_mask + ) + rejected_input_ids, rejected_attention_mask = self._pad_to_length( + rejected_input_ids, rejected_attention_mask + ) input_ids = torch.stack((chosen_input_ids, rejected_input_ids), dim=0) - attention_mask = torch.stack((chosen_attention_mask, rejected_attention_mask), dim=0) + attention_mask = torch.stack( + (chosen_attention_mask, 
rejected_attention_mask), dim=0 + ) return { "input_ids": input_ids, diff --git a/Agent0/executor_train/verl/verl/utils/dataset/sft_dataset.py b/Agent0/executor_train/verl/verl/utils/dataset/sft_dataset.py index 2aa7b20..8bde134 100644 --- a/Agent0/executor_train/verl/verl/utils/dataset/sft_dataset.py +++ b/Agent0/executor_train/verl/verl/utils/dataset/sft_dataset.py @@ -58,8 +58,12 @@ def __init__(self, parquet_files: str | ListConfig, tokenizer, config): tokenizer = hf_tokenizer(tokenizer) self.tokenizer: PreTrainedTokenizer = tokenizer - self.prompt_key = prompt_key if isinstance(prompt_key, tuple | list) else [prompt_key] - self.response_key = response_key if isinstance(response_key, tuple | list) else [response_key] + self.prompt_key = ( + prompt_key if isinstance(prompt_key, tuple | list) else [prompt_key] + ) + self.response_key = ( + response_key if isinstance(response_key, tuple | list) else [response_key] + ) self.prompt_dict_keys = prompt_dict_keys if prompt_dict_keys else [] self.response_dict_keys = response_dict_keys if response_dict_keys else [] @@ -70,14 +74,19 @@ def __init__(self, parquet_files: str | ListConfig, tokenizer, config): def _download(self): for i, parquet_file in enumerate(self.parquet_files): - self.parquet_files[i] = copy_to_local(parquet_file, verbose=True, use_shm=self.use_shm) + self.parquet_files[i] = copy_to_local( + parquet_file, verbose=True, use_shm=self.use_shm + ) def _read_files_and_tokenize(self): def series_to_item(ls): import numpy import pandas - while isinstance(ls, pandas.core.series.Series | numpy.ndarray) and len(ls) == 1: + while ( + isinstance(ls, pandas.core.series.Series | numpy.ndarray) + and len(ls) == 1 + ): ls = ls[0] return ls @@ -93,7 +102,9 @@ def series_to_item(ls): # type(x[0]): numpy.ndarray # type(x[0][0]): dict try: - self.prompts = self.prompts.apply(lambda x: series_to_item(x)[key], axis=1) # noqa: B023 + self.prompts = self.prompts.apply( + lambda x: series_to_item(x)[key], axis=1 + ) # noqa: 
B023 except Exception: print(f"self.prompts={self.prompts}") raise @@ -103,7 +114,9 @@ def series_to_item(ls): self.responses = self.dataframe[self.response_key] for key in self.response_dict_keys: try: - self.responses = self.responses.apply(lambda x: series_to_item(x)[key], axis=1) # noqa: B023 + self.responses = self.responses.apply( + lambda x: series_to_item(x)[key], axis=1 + ) # noqa: B023 except Exception: print(f"self.responses={self.responses}") raise @@ -124,15 +137,21 @@ def __getitem__(self, item): prompt_chat = [{"role": "user", "content": prompt}] # string - prompt_chat_str = tokenizer.apply_chat_template(prompt_chat, add_generation_prompt=True, tokenize=False) + prompt_chat_str = tokenizer.apply_chat_template( + prompt_chat, add_generation_prompt=True, tokenize=False + ) response_chat_str = response + tokenizer.eos_token # tokenize - prompt_ids_output = tokenizer(prompt_chat_str, return_tensors="pt", add_special_tokens=False) + prompt_ids_output = tokenizer( + prompt_chat_str, return_tensors="pt", add_special_tokens=False + ) prompt_ids = prompt_ids_output["input_ids"][0] prompt_attention_mask = prompt_ids_output["attention_mask"][0] - response_ids_output = tokenizer(response_chat_str, return_tensors="pt", add_special_tokens=False) + response_ids_output = tokenizer( + response_chat_str, return_tensors="pt", add_special_tokens=False + ) response_ids = response_ids_output["input_ids"][0] response_attention_mask = response_ids_output["attention_mask"][0] @@ -140,16 +159,22 @@ def __getitem__(self, item): response_length = response_ids.shape[0] input_ids = torch.cat((prompt_ids, response_ids), dim=-1) - attention_mask = torch.cat((prompt_attention_mask, response_attention_mask), dim=-1) + attention_mask = torch.cat( + (prompt_attention_mask, response_attention_mask), dim=-1 + ) # padding to max length sequence_length = input_ids.shape[0] if sequence_length < self.max_length: padded_input_ids = ( - torch.ones(size=(self.max_length - sequence_length,), 
dtype=input_ids.dtype) + torch.ones( + size=(self.max_length - sequence_length,), dtype=input_ids.dtype + ) * self.tokenizer.pad_token_id ) - padded_attention_mask = torch.zeros(size=(self.max_length - sequence_length,), dtype=attention_mask.dtype) + padded_attention_mask = torch.zeros( + size=(self.max_length - sequence_length,), dtype=attention_mask.dtype + ) input_ids = torch.cat((input_ids, padded_input_ids)) attention_mask = torch.cat((attention_mask, padded_attention_mask)) @@ -162,9 +187,13 @@ def __getitem__(self, item): input_ids = input_ids[: self.max_length] attention_mask = attention_mask[: self.max_length] elif self.truncation == "error": - raise NotImplementedError(f"{sequence_length=} is larger than {self.max_length=}") + raise NotImplementedError( + f"{sequence_length=} is larger than {self.max_length=}" + ) else: - raise NotImplementedError(f"Unknown truncation method {self.truncation}") + raise NotImplementedError( + f"Unknown truncation method {self.truncation}" + ) position_ids = compute_position_id_with_mask(attention_mask) diff --git a/Agent0/executor_train/verl/verl/utils/dataset/vision_utils.py b/Agent0/executor_train/verl/verl/utils/dataset/vision_utils.py index 75cce7f..6bd476f 100644 --- a/Agent0/executor_train/verl/verl/utils/dataset/vision_utils.py +++ b/Agent0/executor_train/verl/verl/utils/dataset/vision_utils.py @@ -92,14 +92,18 @@ def process_video( return fetch_video(video) -def process_multi_modal_inputs_for_minicpmo(input_ids, attention_mask, position_ids, cu_seqlens, multi_modal_inputs): +def process_multi_modal_inputs_for_minicpmo( + input_ids, attention_mask, position_ids, cu_seqlens, multi_modal_inputs +): # Adjust image bounds based on left padding and cumulative sequence lengths # This is necessary for MiniCPM-o's vision-language alignment left_padding_length = torch.argmax(attention_mask, dim=1) image_bounds = [] for i in range(len(multi_modal_inputs["image_bound"])): image_bound = ( - 
multi_modal_inputs["image_bound"][i].to(left_padding_length.device) - left_padding_length[i] + cu_seqlens[i] + multi_modal_inputs["image_bound"][i].to(left_padding_length.device) + - left_padding_length[i] + + cu_seqlens[i] ) image_bounds.append(image_bound) diff --git a/Agent0/executor_train/verl/verl/utils/debug/trajectory_tracker.py b/Agent0/executor_train/verl/verl/utils/debug/trajectory_tracker.py index 73afb85..ea64cae 100644 --- a/Agent0/executor_train/verl/verl/utils/debug/trajectory_tracker.py +++ b/Agent0/executor_train/verl/verl/utils/debug/trajectory_tracker.py @@ -80,9 +80,9 @@ def get_trajectory_tracker(): hdfs_dir = os.getenv("VERL_TRACKER_HDFS_DIR", default=None) verbose = os.getenv("VERL_TRACKER_VERBOSE", default="0") == "1" assert hdfs_dir is not None - tracker = TrajectoryTracker.options(name="global_tracker", get_if_exists=True, lifetime="detached").remote( - hdfs_dir, verbose - ) + tracker = TrajectoryTracker.options( + name="global_tracker", get_if_exists=True, lifetime="detached" + ).remote(hdfs_dir, verbose) return tracker diff --git a/Agent0/executor_train/verl/verl/utils/device.py b/Agent0/executor_train/verl/verl/utils/device.py index ed85b0d..a5fc19f 100644 --- a/Agent0/executor_train/verl/verl/utils/device.py +++ b/Agent0/executor_train/verl/verl/utils/device.py @@ -61,7 +61,9 @@ def get_torch_device() -> any: try: return getattr(torch, device_name) except AttributeError: - logger.warning(f"Device namespace '{device_name}' not found in torch, try to load torch.cuda.") + logger.warning( + f"Device namespace '{device_name}' not found in torch, try to load torch.cuda." + ) return torch.cuda @@ -83,4 +85,6 @@ def get_nccl_backend() -> str: elif is_npu_available: return "hccl" else: - raise RuntimeError(f"No available nccl backend found on device type {get_device_name()}.") + raise RuntimeError( + f"No available nccl backend found on device type {get_device_name()}." 
+ ) diff --git a/Agent0/executor_train/verl/verl/utils/experimental/torch_functional.py b/Agent0/executor_train/verl/verl/utils/experimental/torch_functional.py index 0b4ce5c..6fcc813 100644 --- a/Agent0/executor_train/verl/verl/utils/experimental/torch_functional.py +++ b/Agent0/executor_train/verl/verl/utils/experimental/torch_functional.py @@ -55,14 +55,18 @@ def _fused_linear_for_ppo_bwd( # Gradient from log_probs if dlog_probs is not None: - one_hot_input = torch.zeros_like(logits).scatter_(-1, input_ids.unsqueeze(-1), 1) + one_hot_input = torch.zeros_like(logits).scatter_( + -1, input_ids.unsqueeze(-1), 1 + ) dlogits += dlog_probs.to(torch.float32).unsqueeze(-1) * (one_hot_input - probs) # Gradient from entropy if dentropy is not None: log_probs = logits.log_softmax(dim=-1) entropy = torch.logsumexp(logits, dim=-1) - torch.sum(probs * logits, dim=-1) - dlogits += probs * (log_probs + entropy.unsqueeze(-1)) * (-dentropy.unsqueeze(-1)) + dlogits += ( + probs * (log_probs + entropy.unsqueeze(-1)) * (-dentropy.unsqueeze(-1)) + ) dlogits = dlogits.to(orig_dtype) / temperature @@ -86,11 +90,16 @@ def forward( # Cast to a 2D tensor of the shape [T, D] for ease of working orig_ndim = hidden_states.ndim - assert orig_ndim in (2, 3), f"Invalid hidden_states shape, received {hidden_states.shape}" + assert orig_ndim in ( + 2, + 3, + ), f"Invalid hidden_states shape, received {hidden_states.shape}" orig_batch_size = -1 if orig_ndim == 3: - assert input_ids.ndim == 2, f"input_ids shape doesn't match, {hidden_states.shape} {input_ids.shape}" + assert ( + input_ids.ndim == 2 + ), f"input_ids shape doesn't match, {hidden_states.shape} {input_ids.shape}" orig_batch_size = hidden_states.shape[0] hidden_states = hidden_states.flatten(0, 1) input_ids = input_ids.flatten(0, 1) @@ -98,7 +107,9 @@ def forward( T = hidden_states.shape[0] # Allocate memory for outputs - output_requires_grad = hidden_states.requires_grad or vocab_weights.requires_grad + output_requires_grad = ( + 
hidden_states.requires_grad or vocab_weights.requires_grad + ) log_probs = hidden_states.new_zeros(T, requires_grad=output_requires_grad) entropy = hidden_states.new_zeros(T, requires_grad=output_requires_grad) @@ -129,7 +140,11 @@ def forward( return log_probs, entropy @staticmethod - def backward(ctx, dlog_probs: Optional[torch.FloatTensor], dentropy: Optional[torch.FloatTensor]): + def backward( + ctx, + dlog_probs: Optional[torch.FloatTensor], + dentropy: Optional[torch.FloatTensor], + ): assert dlog_probs is not None or dentropy is not None hidden_states, vocab_weights, input_ids = ctx.saved_tensors diff --git a/Agent0/executor_train/verl/verl/utils/flops_counter.py b/Agent0/executor_train/verl/verl/utils/flops_counter.py index 1bed929..a613f78 100644 --- a/Agent0/executor_train/verl/verl/utils/flops_counter.py +++ b/Agent0/executor_train/verl/verl/utils/flops_counter.py @@ -105,7 +105,11 @@ def _estimate_qwen2_flops(self, tokens_sum, batch_seqlens, delta_time): num_attention_heads = self.config.num_attention_heads intermediate_size = self.config.intermediate_size - head_dim = getattr(self.config, "head_dim", self.config.hidden_size // self.config.num_attention_heads) + head_dim = getattr( + self.config, + "head_dim", + self.config.hidden_size // self.config.num_attention_heads, + ) q_size = num_attention_heads * head_dim k_size = num_key_value_heads * head_dim v_size = num_key_value_heads * head_dim @@ -113,7 +117,9 @@ def _estimate_qwen2_flops(self, tokens_sum, batch_seqlens, delta_time): # non-attn per layer parm # Qwen2/LLama use SwiGelu, gate, having up and down linear layer in mlp mlp_N = hidden_size * intermediate_size * 3 - attn_linear_N = hidden_size * (q_size + k_size + v_size + num_attention_heads * head_dim) + attn_linear_N = hidden_size * ( + q_size + k_size + v_size + num_attention_heads * head_dim + ) emd_and_lm_head_N = vocab_size * hidden_size * 2 # non-attn all_layer parm dense_N = (mlp_N + attn_linear_N) * num_hidden_layers + 
emd_and_lm_head_N @@ -124,7 +130,9 @@ def _estimate_qwen2_flops(self, tokens_sum, batch_seqlens, delta_time): seqlen_square_sum = 0 for seqlen in batch_seqlens: seqlen_square_sum += seqlen * seqlen - attn_qkv_flops = 12 * seqlen_square_sum * head_dim * num_attention_heads * num_hidden_layers + attn_qkv_flops = ( + 12 * seqlen_square_sum * head_dim * num_attention_heads * num_hidden_layers + ) # all_layer & all_token fwd & bwd flops flops_all_token = dense_N_flops + attn_qkv_flops @@ -146,7 +154,9 @@ def _estimate_deepseek_v3_flops(self, tokens_sum, batch_seqlens, delta_time): # non-attn per layer parm moe_gata_N = hidden_size * moe_num_expert # moe has fc1_1, fc1_2 and fc2 using SwiGLU in ExpertMlp layer & shared experts - moe_expertmlp_N = hidden_size * moe_intermediate_size * (moe_topk + share_expert_num) * 3 + moe_expertmlp_N = ( + hidden_size * moe_intermediate_size * (moe_topk + share_expert_num) * 3 + ) # MLA attn attn_linear_N = 0 q_head_dim = self.config.qk_nope_head_dim + self.config.qk_rope_head_dim @@ -156,7 +166,9 @@ def _estimate_deepseek_v3_flops(self, tokens_sum, batch_seqlens, delta_time): attn_linear_N += hidden_size * self.config.q_lora_rank attn_linear_N += num_query_heads * q_head_dim * self.config.q_lora_rank - attn_linear_N += hidden_size * (self.config.kv_lora_rank + self.config.qk_rope_head_dim) + attn_linear_N += hidden_size * ( + self.config.kv_lora_rank + self.config.qk_rope_head_dim + ) attn_linear_N += ( num_query_heads * (q_head_dim - self.config.qk_rope_head_dim + self.config.v_head_dim) @@ -166,8 +178,10 @@ def _estimate_deepseek_v3_flops(self, tokens_sum, batch_seqlens, delta_time): emd_and_lm_head_N = vocab_size * hidden_size * 2 # non-attn all_layer parm moe_N = ( - (moe_gata_N + moe_expertmlp_N + attn_linear_N) * (num_hidden_layers - first_k_dense_replace) - + (hidden_size * self.config.intermediate_size * 3 + attn_linear_N) * first_k_dense_replace + (moe_gata_N + moe_expertmlp_N + attn_linear_N) + * (num_hidden_layers - 
first_k_dense_replace) + + (hidden_size * self.config.intermediate_size * 3 + attn_linear_N) + * first_k_dense_replace + emd_and_lm_head_N ) # non-attn all_layer & all_token fwd & bwd flops @@ -195,15 +209,24 @@ def _estimate_qwen2_moe_flops(self, tokens_sum, batch_seqlens, delta_time): moe_topk = self.config.num_experts_per_tok num_experts = self.config.num_experts - head_dim = getattr(self.config, "head_dim", self.config.hidden_size // self.config.num_attention_heads) + head_dim = getattr( + self.config, + "head_dim", + self.config.hidden_size // self.config.num_attention_heads, + ) q_size = num_attention_heads * head_dim k_size = num_key_value_heads * head_dim v_size = num_key_value_heads * head_dim # non-attn per layer parm # gate + moe export - moe_mlp_N = hidden_size * moe_topk * moe_intermediate_size * 3 + hidden_size * num_experts - attn_linear_N = hidden_size * (q_size + k_size + v_size + num_attention_heads * head_dim) + moe_mlp_N = ( + hidden_size * moe_topk * moe_intermediate_size * 3 + + hidden_size * num_experts + ) + attn_linear_N = hidden_size * ( + q_size + k_size + v_size + num_attention_heads * head_dim + ) emd_and_lm_head_N = vocab_size * hidden_size * 2 # non-attn all_layer parm dense_N = (moe_mlp_N + attn_linear_N) * num_hidden_layers + emd_and_lm_head_N @@ -214,7 +237,9 @@ def _estimate_qwen2_moe_flops(self, tokens_sum, batch_seqlens, delta_time): seqlen_square_sum = 0 for seqlen in batch_seqlens: seqlen_square_sum += seqlen * seqlen - attn_qkv_flops = 12 * seqlen_square_sum * head_dim * num_attention_heads * num_hidden_layers + attn_qkv_flops = ( + 12 * seqlen_square_sum * head_dim * num_attention_heads * num_hidden_layers + ) # all_layer & all_token fwd & bwd flops flops_all_token = dense_N_flops + attn_qkv_flops @@ -235,7 +260,9 @@ def estimate_flops(self, batch_seqlens, delta_time): promised_flops (float): The expected FLOPS of the current device. 
""" tokens_sum = sum(batch_seqlens) - func = self.estimate_func.get(self.config.model_type, self._estimate_unknown_flops) + func = self.estimate_func.get( + self.config.model_type, self._estimate_unknown_flops + ) estimated_flops = func(tokens_sum, batch_seqlens, delta_time) promised_flops = get_device_flops() return estimated_flops, promised_flops diff --git a/Agent0/executor_train/verl/verl/utils/fs.py b/Agent0/executor_train/verl/verl/utils/fs.py index 7cc1130..d246024 100644 --- a/Agent0/executor_train/verl/verl/utils/fs.py +++ b/Agent0/executor_train/verl/verl/utils/fs.py @@ -144,7 +144,9 @@ def copy_to_shm(src: str): """ shm_model_root = "/dev/shm/verl-cache/" src_abs = os.path.abspath(os.path.normpath(src)) - dest = os.path.join(shm_model_root, hashlib.md5(src_abs.encode("utf-8")).hexdigest()) + dest = os.path.join( + shm_model_root, hashlib.md5(src_abs.encode("utf-8")).hexdigest() + ) os.makedirs(dest, exist_ok=True) dest = os.path.join(dest, os.path.basename(src_abs)) if os.path.exists(dest) and verify_copy(src, dest): @@ -166,11 +168,15 @@ def _record_directory_structure(folder_path): with open(record_file, "w") as f: for root, dirs, files in os.walk(folder_path): for dir_name in dirs: - relative_dir = os.path.relpath(os.path.join(root, dir_name), folder_path) + relative_dir = os.path.relpath( + os.path.join(root, dir_name), folder_path + ) f.write(f"dir:{relative_dir}\n") for file_name in files: if file_name != ".directory_record.txt": - relative_file = os.path.relpath(os.path.join(root, file_name), folder_path) + relative_file = os.path.relpath( + os.path.join(root, file_name), folder_path + ) f.write(f"file:{relative_file}\n") return record_file @@ -185,7 +191,9 @@ def _check_directory_structure(folder_path, record_file): existing_entries.add(f"dir:{relative_dir}") for file_name in files: if file_name != ".directory_record.txt": - relative_file = os.path.relpath(os.path.join(root, file_name), folder_path) + relative_file = os.path.relpath( + 
os.path.join(root, file_name), folder_path + ) existing_entries.add(f"file:{relative_file}") with open(record_file) as f: recorded_entries = set(f.read().splitlines()) @@ -193,7 +201,12 @@ def _check_directory_structure(folder_path, record_file): def copy_to_local( - src: str, cache_dir=None, filelock=".file.lock", verbose=False, always_recopy=False, use_shm: bool = False + src: str, + cache_dir=None, + filelock=".file.lock", + verbose=False, + always_recopy=False, + use_shm: bool = False, ) -> str: """Copy files/directories from HDFS to local cache with validation. @@ -209,7 +222,9 @@ def copy_to_local( str: Local filesystem path to copied resource """ # Save to a local path for persistence. - local_path = copy_local_path_from_hdfs(src, cache_dir, filelock, verbose, always_recopy) + local_path = copy_local_path_from_hdfs( + src, cache_dir, filelock, verbose, always_recopy + ) # Load into shm to improve efficiency. if use_shm: return copy_to_shm(local_path) @@ -222,7 +237,9 @@ def copy_local_path_from_hdfs( """Deprecated. Please use copy_to_local instead.""" from filelock import FileLock - assert src[-1] != "/", f"Make sure the last char in src is not / because it will cause error. Got {src}" + assert ( + src[-1] != "/" + ), f"Make sure the last char in src is not / because it will cause error. Got {src}" if is_non_local(src): # download from hdfs to local @@ -252,7 +269,9 @@ def copy_local_path_from_hdfs( record_file = os.path.join(local_path, ".directory_record.txt") if not _check_directory_structure(local_path, record_file): if verbose: - print(f"Recopy from {src} to {local_path} due to missing files or directories.") + print( + f"Recopy from {src} to {local_path} due to missing files or directories." 
+ ) shutil.rmtree(local_path, ignore_errors=True) copy(src, local_path) _record_directory_structure(local_path) diff --git a/Agent0/executor_train/verl/verl/utils/fsdp_utils.py b/Agent0/executor_train/verl/verl/utils/fsdp_utils.py index 7465b40..06aad57 100644 --- a/Agent0/executor_train/verl/verl/utils/fsdp_utils.py +++ b/Agent0/executor_train/verl/verl/utils/fsdp_utils.py @@ -27,17 +27,35 @@ from torch.distributed import DeviceMesh from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp._runtime_utils import _lazy_init -from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy +from torch.distributed.fsdp.wrap import ( + size_based_auto_wrap_policy, + transformer_auto_wrap_policy, +) from transformers.trainer_pt_utils import get_module_class_from_name from verl.utils.device import get_device_id, get_device_name, get_torch_device if version.parse(torch.__version__) >= version.parse("2.6"): - from torch.distributed.fsdp import CPUOffloadPolicy, FSDPModule, MixedPrecisionPolicy, fully_shard + from torch.distributed.fsdp import ( + CPUOffloadPolicy, + FSDPModule, + MixedPrecisionPolicy, + fully_shard, + ) elif version.parse(torch.__version__) >= version.parse("2.4"): - from torch.distributed._composable.fsdp import CPUOffloadPolicy, FSDPModule, MixedPrecisionPolicy, fully_shard + from torch.distributed._composable.fsdp import ( + CPUOffloadPolicy, + FSDPModule, + MixedPrecisionPolicy, + fully_shard, + ) else: - fully_shard, MixedPrecisionPolicy, FSDPModule, CPUOffloadPolicy = None, None, None, None + fully_shard, MixedPrecisionPolicy, FSDPModule, CPUOffloadPolicy = ( + None, + None, + None, + None, + ) def init_fn(x: torch.nn.Module): @@ -53,9 +71,17 @@ def get_init_weight_context_manager(use_meta_tensor=True, mesh: DeviceMesh = Non cpu_init_weights = lambda: torch.device("cpu") if use_meta_tensor: if mesh is None: - init_context = init_empty_weights if torch.distributed.get_rank() != 0 
else cpu_init_weights + init_context = ( + init_empty_weights + if torch.distributed.get_rank() != 0 + else cpu_init_weights + ) else: - init_context = init_empty_weights if mesh.get_coordinate()[-1] != 0 else cpu_init_weights + init_context = ( + init_empty_weights + if mesh.get_coordinate()[-1] != 0 + else cpu_init_weights + ) else: init_context = cpu_init_weights return init_context @@ -106,18 +132,24 @@ def lambda_policy_fn(module): and module.weight.requires_grad ) - lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn) + lambda_policy = functools.partial( + lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn + ) policies.append(lambda_policy) if min_num_params > 0: - size_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=min_num_params) + size_policy = functools.partial( + size_based_auto_wrap_policy, min_num_params=min_num_params + ) policies.append(size_policy) elif fsdp_transformer_layer_cls_to_wrap is not None: transformer_cls_to_wrap = set() for layer_class in fsdp_transformer_layer_cls_to_wrap: transformer_cls = get_module_class_from_name(module, layer_class) if transformer_cls is None: - raise Exception("Could not find the transformer layer class to wrap in the model.") + raise Exception( + "Could not find the transformer layer class to wrap in the model." 
+ ) else: transformer_cls_to_wrap.add(transformer_cls) @@ -183,7 +215,9 @@ def load_fsdp_model_to_gpu(model: FSDP): if handle._offload_params: continue flat_param = handle.flat_param - handle.flat_param_to(torch.device(f"{get_device_name()}:{device_id}"), non_blocking=True) + handle.flat_param_to( + torch.device(f"{get_device_name()}:{device_id}"), non_blocking=True + ) # the following still keeps id(._local_shard) != id(.data) flat_param._local_shard = flat_param.data @@ -240,7 +274,9 @@ def register_empty_parameter(module, name, param): param_cls = type(module._parameters[name]) kwargs = module._parameters[name].__dict__ kwargs["requires_grad"] = param.requires_grad - module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs) + module._parameters[name] = param_cls( + module._parameters[name].to(device), **kwargs + ) registered.add(module._parameters[name]) try: @@ -289,7 +325,9 @@ def parallel_load_safetensors(filepath): ckpt_chunks = sorted(safetensors2param.keys()) world_size = dist.get_world_size() size = int(math.ceil(total_files / world_size)) - ckpt_chunks = [ckpt_chunks[rank * size : rank * size + size] for rank in range(world_size)] + ckpt_chunks = [ + ckpt_chunks[rank * size : rank * size + size] for rank in range(world_size) + ] shard_states = {} device = get_device_id() @@ -307,7 +345,9 @@ def parallel_load_safetensors(filepath): return shard_states -def parallel_init_module_fn(module: torch.nn.Module, shard_states: dict[str, torch.nn.Parameter]): +def parallel_init_module_fn( + module: torch.nn.Module, shard_states: dict[str, torch.nn.Parameter] +): """ Generate a function to initialize sub-modules in the `module` with `shard_states` from huggingface checkpoint. 
@@ -322,7 +362,8 @@ def parallel_init_module_fn(module: torch.nn.Module, shard_states: dict[str, tor state2fqn = {} for name, state in itertools.chain( - module.named_parameters(remove_duplicate=False), module.named_buffers(remove_duplicate=False) + module.named_parameters(remove_duplicate=False), + module.named_buffers(remove_duplicate=False), ): state2fqn.setdefault(state, []).append(name) # remove standalone parameters and buffers @@ -334,7 +375,10 @@ def create_and_sync_state(param_name, state, is_param): assert param_name in shard_states, f"{param_name} not loaded" device = get_device_id() if is_param: - param = torch.nn.Parameter(torch.empty_like(state.data, device=device), requires_grad=state.requires_grad) + param = torch.nn.Parameter( + torch.empty_like(state.data, device=device), + requires_grad=state.requires_grad, + ) else: # buffer param = torch.empty_like(state.data, device=device) loaded = shard_states[param_name] @@ -350,7 +394,9 @@ def create_and_sync_state(param_name, state, is_param): return param def init_fn(sub_mod: torch.nn.Module, recurse: bool = True): - param_and_buffers = tuple(sub_mod.named_parameters(recurse=False)) + tuple(sub_mod.named_buffers(recurse=False)) + param_and_buffers = tuple(sub_mod.named_parameters(recurse=False)) + tuple( + sub_mod.named_buffers(recurse=False) + ) # param_and_buffers = sorted(sub_mod.named_parameters(recurse=False), key=lambda x: x[0]) for name, state in param_and_buffers: if not state.is_meta: @@ -368,7 +414,9 @@ def init_fn(sub_mod: torch.nn.Module, recurse: bool = True): # for shared parameter, we get it from the first time it is created if state in shared: if state not in materialized_states: - materialized_states[state] = create_and_sync_state(fqn, state, is_param) + materialized_states[state] = create_and_sync_state( + fqn, state, is_param + ) else: if fqn in shard_states: shard_states.pop(fqn) @@ -407,7 +455,9 @@ def get_fsdp_state_ctx(model, state_type, state_cfg, optim_cfg): return nullcontext() 
-def get_fsdp_full_state_dict(model: torch.nn.Module, offload_to_cpu: bool = True, rank0_only: bool = True): +def get_fsdp_full_state_dict( + model: torch.nn.Module, offload_to_cpu: bool = True, rank0_only: bool = True +): """ Get the full state dict from an FSDP model. @@ -425,17 +475,27 @@ def get_fsdp_full_state_dict(model: torch.nn.Module, offload_to_cpu: bool = True if fsdp_version(model) == 1: from torch.distributed.fsdp import FullStateDictConfig, StateDictType - state_dict_config = FullStateDictConfig(offload_to_cpu=offload_to_cpu, rank0_only=rank0_only) + state_dict_config = FullStateDictConfig( + offload_to_cpu=offload_to_cpu, rank0_only=rank0_only + ) with get_fsdp_state_ctx( - model, state_type=StateDictType.FULL_STATE_DICT, state_cfg=state_dict_config, optim_cfg=None + model, + state_type=StateDictType.FULL_STATE_DICT, + state_cfg=state_dict_config, + optim_cfg=None, ): state_dict = model.state_dict() return state_dict elif fsdp_version(model) == 2: - from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict + from torch.distributed.checkpoint.state_dict import ( + StateDictOptions, + get_model_state_dict, + ) state_dict_config = StateDictOptions( - full_state_dict=True, cpu_offload=offload_to_cpu, broadcast_from_rank0=not rank0_only + full_state_dict=True, + cpu_offload=offload_to_cpu, + broadcast_from_rank0=not rank0_only, ) state_dict = get_model_state_dict(model, options=state_dict_config) return state_dict @@ -443,7 +503,9 @@ def get_fsdp_full_state_dict(model: torch.nn.Module, offload_to_cpu: bool = True raise NotImplementedError(f"Unknown FSDP version {fsdp_version}") -def fsdp2_load_full_state_dict(model: torch.nn.Module, full_state: dict, device_mesh=None, cpu_offload=None): +def fsdp2_load_full_state_dict( + model: torch.nn.Module, full_state: dict, device_mesh=None, cpu_offload=None +): """ Loads the full state dict (could be only on rank 0) into the sharded model. 
This is done by broadcasting the parameters from rank 0 to all other ranks. This function modifies the model in-place. @@ -452,7 +514,10 @@ def fsdp2_load_full_state_dict(model: torch.nn.Module, full_state: dict, device_ model (`torch.nn.Module`): The model to load the state dict into full_state (`dict`): The full state dict to load, can only be on rank 0 """ - from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict + from torch.distributed.checkpoint.state_dict import ( + StateDictOptions, + set_model_state_dict, + ) # To broadcast, it needs to be instantiated in the GPU. if dist.get_rank() == 0: @@ -461,7 +526,9 @@ def fsdp2_load_full_state_dict(model: torch.nn.Module, full_state: dict, device_ model = model.to_empty(device=get_device_id()) cpu_offload = cpu_offload is not None - options = StateDictOptions(full_state_dict=True, cpu_offload=cpu_offload, broadcast_from_rank0=True) + options = StateDictOptions( + full_state_dict=True, cpu_offload=cpu_offload, broadcast_from_rank0=True + ) set_model_state_dict(model, full_state, options=options) # rotary_emb is not in state_dict, so we need to broadcast it manually @@ -476,7 +543,9 @@ def fsdp2_load_full_state_dict(model: torch.nn.Module, full_state: dict, device_ def apply_fsdp2(model, fsdp_kwargs, config): """model: AutoModelForCausalLM""" - assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)" + assert ( + CPUOffloadPolicy is not None + ), "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)" default_transformer_cls_names_to_wrap = getattr(model, "_no_split_modules", None) fsdp_transformer_layer_cls_to_wrap = config.get("wrap_policy", {}).get( @@ -486,7 +555,10 @@ def apply_fsdp2(model, fsdp_kwargs, config): if isinstance(fsdp_transformer_layer_cls_to_wrap, str): fsdp_transformer_layer_cls_to_wrap = [fsdp_transformer_layer_cls_to_wrap] - assert len(fsdp_transformer_layer_cls_to_wrap) > 0 and 
fsdp_transformer_layer_cls_to_wrap[0] is not None + assert ( + len(fsdp_transformer_layer_cls_to_wrap) > 0 + and fsdp_transformer_layer_cls_to_wrap[0] is not None + ) modules = [] for name, module in model.named_modules(): @@ -497,10 +569,14 @@ def apply_fsdp2(model, fsdp_kwargs, config): for idx, module in enumerate(modules): fully_shard(module, **fsdp_kwargs) - fully_shard(model, **fsdp_kwargs) # fsdp2 will not reshard_after_forward for root module + fully_shard( + model, **fsdp_kwargs + ) # fsdp2 will not reshard_after_forward for root module -def fsdp2_clip_grad_norm_(parameters, max_norm, norm_type=2.0, error_if_nonfinite=False, foreach=None): +def fsdp2_clip_grad_norm_( + parameters, max_norm, norm_type=2.0, error_if_nonfinite=False, foreach=None +): """torch.nn.utils.clip_grad_norm_ cann't run on cpu parameter DTensor""" from torch.nn.utils.clip_grad import _clip_grads_with_norm_, _get_total_norm @@ -538,16 +614,22 @@ def __prefix_submodules(module, prefix): peft_model = getattr(fsdp_module, "_fsdp_wrapped_module", fsdp_module) for prefix in prefix_list: for name, submodule in __prefix_submodules(fsdp_module, prefix): - prefix = name.replace("_fsdp_wrapped_module.base_model.model.", "base_model.model.") + prefix = name.replace( + "_fsdp_wrapped_module.base_model.model.", "base_model.model." 
+ ) if name.endswith(".model") or name.endswith(".layers"): continue if fsdp_version(submodule) > 0: with FSDP.summon_full_params(submodule, writeback=False): - sub_lora_params = get_peft_model_state_dict(peft_model, state_dict=submodule.state_dict()) + sub_lora_params = get_peft_model_state_dict( + peft_model, state_dict=submodule.state_dict() + ) sub_lora_params = { - f"{prefix}.{name}": param.full_tensor().detach().cpu() - if hasattr(param, "full_tensor") - else param.detach().cpu() + f"{prefix}.{name}": ( + param.full_tensor().detach().cpu() + if hasattr(param, "full_tensor") + else param.detach().cpu() + ) for name, param in sub_lora_params.items() } lora_params.update(sub_lora_params) diff --git a/Agent0/executor_train/verl/verl/utils/hdfs_io.py b/Agent0/executor_train/verl/verl/utils/hdfs_io.py index 31edda1..9062657 100644 --- a/Agent0/executor_train/verl/verl/utils/hdfs_io.py +++ b/Agent0/executor_train/verl/verl/utils/hdfs_io.py @@ -113,9 +113,13 @@ def copy(src: str, dst: str, **kwargs) -> bool: def _copy(from_path: str, to_path: str, timeout: int = None) -> bool: if to_path.startswith("hdfs"): if from_path.startswith("hdfs"): - returncode = _run_cmd(_hdfs_cmd(f"-cp -f {from_path} {to_path}"), timeout=timeout) + returncode = _run_cmd( + _hdfs_cmd(f"-cp -f {from_path} {to_path}"), timeout=timeout + ) else: - returncode = _run_cmd(_hdfs_cmd(f"-put -f {from_path} {to_path}"), timeout=timeout) + returncode = _run_cmd( + _hdfs_cmd(f"-put -f {from_path} {to_path}"), timeout=timeout + ) else: if from_path.startswith("hdfs"): returncode = _run_cmd( diff --git a/Agent0/executor_train/verl/verl/utils/kernel/__init__.py b/Agent0/executor_train/verl/verl/utils/kernel/__init__.py index e32d583..4d8acb1 100644 --- a/Agent0/executor_train/verl/verl/utils/kernel/__init__.py +++ b/Agent0/executor_train/verl/verl/utils/kernel/__init__.py @@ -28,4 +28,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. - diff --git a/Agent0/executor_train/verl/verl/utils/kernel/kernels.py b/Agent0/executor_train/verl/verl/utils/kernel/kernels.py index a125bac..6f55026 100644 --- a/Agent0/executor_train/verl/verl/utils/kernel/kernels.py +++ b/Agent0/executor_train/verl/verl/utils/kernel/kernels.py @@ -92,10 +92,10 @@ class BackwardEnum: Enum for the backward method. """ - _Total_Fuse_MN = ( - 0 # Fuse d_logits & d_hidden & d_weight, no intermediate storage, requires fp32 for d_hidden & d_weight + _Total_Fuse_MN = 0 # Fuse d_logits & d_hidden & d_weight, no intermediate storage, requires fp32 for d_hidden & d_weight + _Total_Separate = ( + 1 # Store d_logits, no special requirements for d_hidden & d_weight ) - _Total_Separate = 1 # Store d_logits, no special requirements for d_hidden & d_weight _Split_Dlogits_N = 2 # split d_logits along its N dimension, aka. vocab_size _Split_Dlogits_M = 3 # split d_logits along its M dimension, aka. 
num_tokens @@ -118,7 +118,13 @@ def set_backward_method(backward_method: BackwardEnum): @triton.autotune( - configs=[triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32}, num_stages=3, num_warps=8)], + configs=[ + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32}, + num_stages=3, + num_warps=8, + ) + ], key=["num_tokens", "hidden_size", "vocab_size"], ) @triton.jit @@ -169,7 +175,9 @@ def efficient_entropy_kernel_general_mainloop( # create pointers for the first blocks of hidden offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_k = tl.arange(0, BLOCK_SIZE_K) - hidden_ptrs = hidden_ptr + (offs_am[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k) + hidden_ptrs = hidden_ptr + ( + offs_am[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k + ) # load labels for this block labels = tl.load(labels_ptr + offs_am, mask=offs_am < num_tokens) @@ -181,9 +189,13 @@ def efficient_entropy_kernel_general_mainloop( _entropy_b = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) _logprobs = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) for n in range(0, num_pid_n): - offs_bn = pid_n * vocab_per_split + n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_bn = ( + pid_n * vocab_per_split + n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + ) # weight_ptrs = weight_ptr + (offs_k[:, None] * stride_weight_k + offs_bn[None, :] * stride_weight_n) - weight_ptrs = weight_ptr + (offs_bn[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k) + weight_ptrs = weight_ptr + ( + offs_bn[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k + ) # iterate over K dimension logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) @@ -191,7 +203,8 @@ def efficient_entropy_kernel_general_mainloop( # load the next block of hidden and weight _hidden = tl.load( hidden_ptrs, - mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens), + mask=(offs_k[None, :] < 
hidden_size - k * BLOCK_SIZE_K) + & (offs_am[:, None] < num_tokens), other=0.0, ) # _weight = tl.load(weight_ptrs, @@ -236,13 +249,27 @@ def efficient_entropy_kernel_general_mainloop( offs_max_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_max_n = pid_n maximum_ptrs = max_ptr + offs_max_n * stride_max_n + offs_max_m * stride_max_m - tl.store(maximum_ptrs, _max, mask=(offs_max_m < num_tokens) & (offs_max_n < num_splits)) + tl.store( + maximum_ptrs, _max, mask=(offs_max_m < num_tokens) & (offs_max_n < num_splits) + ) # store entropy accu_ptrs = accu_ptr + offs_max_n * stride_accu_n + offs_max_m * stride_accu_m - tl.store(accu_ptrs, _accu, mask=(offs_max_m < num_tokens) & (offs_max_n[None] < num_splits)) - entropy_b_ptrs = entropy_b_ptr + offs_max_n * stride_entropy_b_n + offs_max_m * stride_entropy_b_m - tl.store(entropy_b_ptrs, _entropy_b, mask=(offs_max_m < num_tokens) & (offs_max_n < num_splits)) + tl.store( + accu_ptrs, + _accu, + mask=(offs_max_m < num_tokens) & (offs_max_n[None] < num_splits), + ) + entropy_b_ptrs = ( + entropy_b_ptr + + offs_max_n * stride_entropy_b_n + + offs_max_m * stride_entropy_b_m + ) + tl.store( + entropy_b_ptrs, + _entropy_b, + mask=(offs_max_m < num_tokens) & (offs_max_n < num_splits), + ) # store logprobs vocab_left_idx = pid_n * vocab_per_split + rank * vocab_size @@ -254,7 +281,10 @@ def efficient_entropy_kernel_general_mainloop( tl.store(global_logprobs_ptrs, _logprobs, mask=mask) -@triton.autotune(configs=[triton.Config({"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64})], key=["num_tokens", "num_splits"]) +@triton.autotune( + configs=[triton.Config({"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64})], + key=["num_tokens", "num_splits"], +) @triton.jit def efficient_entropy_triton_kernel_epilogue( max_ptr, @@ -294,16 +324,34 @@ def efficient_entropy_triton_kernel_epilogue( global_entropy_b = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) for pid_n in range(0, tl.cdiv(num_splits, BLOCK_SIZE_N)): offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, 
BLOCK_SIZE_N) - max_ptrs = max_ptr + offs_m[:, None] * stride_max_m + offs_n[None, :] * stride_max_n + max_ptrs = ( + max_ptr + offs_m[:, None] * stride_max_m + offs_n[None, :] * stride_max_n + ) - _max = tl.load(max_ptrs, mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), other=0.0) + _max = tl.load( + max_ptrs, + mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), + other=0.0, + ) - accu_ptrs = accu_ptr + offs_m[:, None] * stride_accu_m + offs_n[None, :] * stride_accu_n - _accu = tl.load(accu_ptrs, mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), other=0.0) + accu_ptrs = ( + accu_ptr + offs_m[:, None] * stride_accu_m + offs_n[None, :] * stride_accu_n + ) + _accu = tl.load( + accu_ptrs, + mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), + other=0.0, + ) - entropy_b_ptrs = entropy_b_ptr + offs_m[:, None] * stride_entropy_b_m + offs_n[None, :] * stride_entropy_b_n + entropy_b_ptrs = ( + entropy_b_ptr + + offs_m[:, None] * stride_entropy_b_m + + offs_n[None, :] * stride_entropy_b_n + ) _entropy_b = tl.load( - entropy_b_ptrs, mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), other=0.0 + entropy_b_ptrs, + mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), + other=0.0, ) # local reduction @@ -314,7 +362,9 @@ def efficient_entropy_triton_kernel_epilogue( _scale = tl.exp(_max - global_max[:, None]) _coeff = tl.exp(_max_old - global_max) global_accu = _coeff * global_accu + tl.sum(_scale * _accu, axis=1) - global_entropy_b = _coeff * global_entropy_b + tl.sum(_scale * _entropy_b, axis=1) + global_entropy_b = _coeff * global_entropy_b + tl.sum( + _scale * _entropy_b, axis=1 + ) # store maximum_ptrs = global_max_ptr + offs_m * stride_global_max @@ -322,7 +372,11 @@ def efficient_entropy_triton_kernel_epilogue( # store entropy_b global_entropy_b = tl.fdiv(global_entropy_b, global_accu) # entropy_b - tl.store(global_entropy_b_ptr + offs_m * 
stride_global_entropy_b, global_entropy_b, mask=offs_m < num_tokens) + tl.store( + global_entropy_b_ptr + offs_m * stride_global_entropy_b, + global_entropy_b, + mask=offs_m < num_tokens, + ) # store entropy global_accu_ptrs = global_accu_ptr + offs_m * stride_global_accu @@ -342,11 +396,16 @@ def efficient_entropy_triton_kernel_epilogue( global_logprobs_scalar = tl.sum(global_logprobs, axis=0) tl.atomic_add(global_logprobs_scalar_ptr, global_logprobs_scalar) elif reduction == 2: - global_logprobs_scalar = tl.sum(global_logprobs, axis=0) / num_tokens.to(tl.float32) + global_logprobs_scalar = tl.sum(global_logprobs, axis=0) / num_tokens.to( + tl.float32 + ) tl.atomic_add(global_logprobs_scalar_ptr, global_logprobs_scalar) -@triton.autotune(configs=[triton.Config({"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64})], key=["num_tokens", "num_splits"]) +@triton.autotune( + configs=[triton.Config({"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64})], + key=["num_tokens", "num_splits"], +) @triton.jit def efficient_entropy_triton_kernel_epilogue_tp( num_tokens, @@ -383,17 +442,23 @@ def efficient_entropy_triton_kernel_epilogue_tp( offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) _reduced_max = tl.load( - reduced_max_ptr + offs_m[:, None] * stride_reduced_max_m + offs_n[None, :] * stride_reduced_max_n, + reduced_max_ptr + + offs_m[:, None] * stride_reduced_max_m + + offs_n[None, :] * stride_reduced_max_n, mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), other=0.0, ) _original_max = tl.load( - original_max_ptr + offs_m[:, None] * stride_original_max_m + offs_n[None, :] * stride_original_max_n, + original_max_ptr + + offs_m[:, None] * stride_original_max_m + + offs_n[None, :] * stride_original_max_n, mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), other=0.0, ) _accu = tl.load( - accu_ptr + offs_m[:, None] * stride_accu_m + offs_n[None, :] * stride_accu_n, + accu_ptr + + offs_m[:, None] * stride_accu_m + + offs_n[None, :] * stride_accu_n, 
mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), other=0.0, ) @@ -410,16 +475,32 @@ def efficient_entropy_triton_kernel_epilogue_tp( # update entropy_b _entropy_b = tl.load( - entropy_b_ptr + offs_m[:, None] * stride_entropy_b_m + offs_n[None, :] * stride_entropy_b_n, + entropy_b_ptr + + offs_m[:, None] * stride_entropy_b_m + + offs_n[None, :] * stride_entropy_b_n, mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), other=0.0, ) - global_entropy_b = _coeff * global_entropy_b + tl.sum(_scale * _entropy_b, axis=1) + global_entropy_b = _coeff * global_entropy_b + tl.sum( + _scale * _entropy_b, axis=1 + ) # store - tl.store(global_max_ptr + offs_m * stride_global_max, global_max, mask=offs_m < num_tokens) - tl.store(global_accu_ptr + offs_m * stride_global_accu, global_accu, mask=offs_m < num_tokens) - tl.store(global_entropy_b_ptr + offs_m * stride_global_entropy_b, global_entropy_b, mask=offs_m < num_tokens) + tl.store( + global_max_ptr + offs_m * stride_global_max, + global_max, + mask=offs_m < num_tokens, + ) + tl.store( + global_accu_ptr + offs_m * stride_global_accu, + global_accu, + mask=offs_m < num_tokens, + ) + tl.store( + global_entropy_b_ptr + offs_m * stride_global_entropy_b, + global_entropy_b, + mask=offs_m < num_tokens, + ) @triton.autotune(configs=[triton.Config({"BLOCK_SIZE_M": 16})], key=["num_tokens"]) @@ -445,21 +526,31 @@ def efficient_entropy_triton_epilogue_tp_update( offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) maximum = tl.load(maximum_ptr + offs_m * stride_maximum, mask=offs_m < num_tokens) - accumulate = tl.load(accumulate_ptr + offs_m * stride_accumulate, mask=offs_m < num_tokens) + accumulate = tl.load( + accumulate_ptr + offs_m * stride_accumulate, mask=offs_m < num_tokens + ) - entropy_b = tl.load(entropy_b_ptr + offs_m * stride_entropy_b, mask=offs_m < num_tokens) + entropy_b = tl.load( + entropy_b_ptr + offs_m * stride_entropy_b, mask=offs_m < num_tokens + ) entropy_b = 
tl.fdiv(entropy_b, accumulate) - tl.store(entropy_b_ptr + offs_m * stride_entropy_b, entropy_b, mask=offs_m < num_tokens) + tl.store( + entropy_b_ptr + offs_m * stride_entropy_b, entropy_b, mask=offs_m < num_tokens + ) entropy = tl.log(accumulate) + maximum - entropy_b tl.store(entropy_ptr + offs_m * stride_entropy, entropy, mask=offs_m < num_tokens) - logprobs = tl.load(logprobs_ptr + offs_m * stride_logprobs, mask=offs_m < num_tokens) + logprobs = tl.load( + logprobs_ptr + offs_m * stride_logprobs, mask=offs_m < num_tokens + ) logprobs = maximum + tl.log(accumulate) - logprobs logprobs = -1 * logprobs if reduction == 0: - tl.store(logprobs_ptr + offs_m * stride_logprobs, logprobs, mask=offs_m < num_tokens) + tl.store( + logprobs_ptr + offs_m * stride_logprobs, logprobs, mask=offs_m < num_tokens + ) elif reduction == 1: logprobs_scalar = tl.sum(logprobs, axis=0) tl.atomic_add(logprobs_scalar_ptr, logprobs_scalar) @@ -490,9 +581,13 @@ def efficient_entropy_forward( assert hidden.shape[0] == labels.shape[0] and hidden.shape[1] == weight.shape[1] _rank = 0 if dist_process_group is None else dist.get_rank(dist_process_group) - _world_size = 1 if dist_process_group is None else dist.get_world_size(dist_process_group) + _world_size = ( + 1 if dist_process_group is None else dist.get_world_size(dist_process_group) + ) - if dist_process_group is not None and not hasattr(efficient_entropy_forward, "_initialized"): + if dist_process_group is not None and not hasattr( + efficient_entropy_forward, "_initialized" + ): global _dedicated_stream, _dedicated_events _dedicated_stream = get_torch_device().Stream(hidden.device) _dedicated_events = [get_torch_device().Event() for _ in range(2)] @@ -507,9 +602,13 @@ def efficient_entropy_forward( if REDUCTION == EntropyReductionEnum._None: if dist_process_group is None: - logprobs = torch.empty((num_tokens,), device=hidden.device, dtype=torch.float32) + logprobs = torch.empty( + (num_tokens,), device=hidden.device, dtype=torch.float32 
+ ) else: - logprobs = torch.zeros((num_tokens,), device=hidden.device, dtype=torch.float32) + logprobs = torch.zeros( + (num_tokens,), device=hidden.device, dtype=torch.float32 + ) elif REDUCTION in (EntropyReductionEnum._Sum, EntropyReductionEnum._Mean): logprobs = torch.empty((), device=hidden.device, dtype=torch.float32) else: @@ -519,24 +618,38 @@ def efficient_entropy_forward( assert logprobs.is_contiguous() and entropy.is_contiguous() maximum = torch.empty_like(entropy) - accumulate_and_entropy_b = torch.empty((num_tokens * 2,), device=hidden.device, dtype=torch.float32) + accumulate_and_entropy_b = torch.empty( + (num_tokens * 2,), device=hidden.device, dtype=torch.float32 + ) accumulate_and_entropy_b_view = accumulate_and_entropy_b.view(2, num_tokens) accumulate = accumulate_and_entropy_b_view[0, :] entropy_b = accumulate_and_entropy_b_view[1, :] - assert maximum.is_contiguous() and accumulate.is_contiguous() and entropy_b.is_contiguous() + assert ( + maximum.is_contiguous() + and accumulate.is_contiguous() + and entropy_b.is_contiguous() + ) vocab_per_split = 1024 assert vocab_per_split % 128 == 0 num_splits = (vocab_size + vocab_per_split - 1) // vocab_per_split - _max = torch.empty((num_tokens, num_splits), device=hidden.device, dtype=torch.float32) - _accu = torch.empty((num_tokens, num_splits), device=hidden.device, dtype=torch.float32) - _entropy_b = torch.empty((num_tokens, num_splits), device=hidden.device, dtype=torch.float32) + _max = torch.empty( + (num_tokens, num_splits), device=hidden.device, dtype=torch.float32 + ) + _accu = torch.empty( + (num_tokens, num_splits), device=hidden.device, dtype=torch.float32 + ) + _entropy_b = torch.empty( + (num_tokens, num_splits), device=hidden.device, dtype=torch.float32 + ) if REDUCTION == EntropyReductionEnum._None: _logprobs = logprobs else: - _logprobs = torch.empty((num_tokens,), device=hidden.device, dtype=torch.float32) + _logprobs = torch.empty( + (num_tokens,), device=hidden.device, 
dtype=torch.float32 + ) assert _accu.is_contiguous() and _entropy_b.is_contiguous() and _max.is_contiguous() assert _accu.is_cuda and _entropy_b.is_cuda and _max.is_cuda @@ -641,7 +754,9 @@ def epilogue_grid(meta): ) get_torch_device().current_stream().wait_event(_dedicated_events[1]) - dist.all_reduce(accumulate_and_entropy_b, op=dist.ReduceOp.SUM, group=dist_process_group) + dist.all_reduce( + accumulate_and_entropy_b, op=dist.ReduceOp.SUM, group=dist_process_group + ) # update logprobs & entropy efficient_entropy_triton_epilogue_tp_update[epilogue_grid]( @@ -667,7 +782,12 @@ def epilogue_grid(meta): @triton.autotune( configs=[ triton.Config( - {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16}, + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 16, + }, num_stages=3, num_warps=8, ) @@ -737,7 +857,9 @@ def efficient_entropy_backward_kernel_general_mainloop_MN( maximum_ptrs = maximum_ptr + offs_am * stride_maximum maximum = tl.load(maximum_ptrs, mask=offs_am < num_tokens, other=0.0) accu_ptrs = accu_ptr + offs_am * stride_accu - accu = tl.load(accu_ptrs, mask=offs_am < num_tokens, other=1e-6) # epsilon to avoid division by zero + accu = tl.load( + accu_ptrs, mask=offs_am < num_tokens, other=1e-6 + ) # epsilon to avoid division by zero accu_rcp = tl.fdiv(1.0, accu) d_entropy_ptrs = d_entropy_ptr + offs_am * stride_d_entropy @@ -756,21 +878,34 @@ def efficient_entropy_backward_kernel_general_mainloop_MN( entropy_b_ptrs = entropy_b_ptr + offs_am * stride_entropy_b entropy_b = tl.load(entropy_b_ptrs, mask=offs_am < num_tokens, other=0.0) - hidden_ptrs = hidden_ptr + (offs_am[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k) + hidden_ptrs = hidden_ptr + ( + offs_am[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k + ) # weight_ptrs = weight_ptr + (offs_k[:, None] * stride_weight_k + offs_bn[None, :] * stride_weight_n) - weight_ptrs = weight_ptr + (offs_bn[:, None] * 
stride_weight_n + offs_k[None, :] * stride_weight_k) + weight_ptrs = weight_ptr + ( + offs_bn[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k + ) labels_ptrs = labels_ptr + offs_am * stride_labels labels = tl.load(labels_ptrs, mask=offs_am < num_tokens, other=0) - d_hidden_ptrs = d_hidden_ptr + offs_am[:, None] * stride_d_hidden_m + offs_k[None, :] * stride_d_hidden_k + d_hidden_ptrs = ( + d_hidden_ptr + + offs_am[:, None] * stride_d_hidden_m + + offs_k[None, :] * stride_d_hidden_k + ) # d_weight_ptrs = d_weight_ptr + offs_k[:, None] * stride_d_weight_k + offs_bn[None, :] * stride_d_weight_n - d_weight_ptrs = d_weight_ptr + offs_bn[:, None] * stride_d_weight_n + offs_k[None, :] * stride_d_weight_k + d_weight_ptrs = ( + d_weight_ptr + + offs_bn[:, None] * stride_d_weight_n + + offs_k[None, :] * stride_d_weight_k + ) logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)): _hidden = tl.load( hidden_ptrs, - mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens), + mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) + & (offs_am[:, None] < num_tokens), other=0.0, ) # _weight = tl.load(weight_ptrs, @@ -778,7 +913,8 @@ def efficient_entropy_backward_kernel_general_mainloop_MN( # other=0.0) _weight = tl.load( weight_ptrs, - mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[:, None] < vocab_size), + mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) + & (offs_bn[:, None] < vocab_size), other=0.0, ) @@ -796,7 +932,11 @@ def efficient_entropy_backward_kernel_general_mainloop_MN( mask = (offs_bn + rank * vocab_size)[None, :] == labels[:, None] d_logits = d_logprobs[:, None] * (exp_logits * accu_rcp[:, None] - mask) - d_logits += d_entropy[:, None] * (-exp_logits * accu_rcp[:, None]) * (logits - entropy_b[:, None]) + d_logits += ( + d_entropy[:, None] + * (-exp_logits * accu_rcp[:, None]) + * (logits - entropy_b[:, None]) + ) # scale 
d_logits by temperature d_logits *= rcp_temperature @@ -805,7 +945,8 @@ def efficient_entropy_backward_kernel_general_mainloop_MN( for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)): _hidden = tl.load( hidden_ptrs, - mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens), + mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) + & (offs_am[:, None] < num_tokens), other=0.0, ) # _d_weight = tl.dot(tl.trans(_hidden).to(tl.float32), d_logits) @@ -816,7 +957,8 @@ def efficient_entropy_backward_kernel_general_mainloop_MN( tl.atomic_add( d_weight_ptrs, _d_weight, - mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[:, None] < vocab_size), + mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) + & (offs_bn[:, None] < vocab_size), ) # _weight = tl.load(weight_ptrs, @@ -825,14 +967,16 @@ def efficient_entropy_backward_kernel_general_mainloop_MN( # _d_hidden = tl.dot(d_logits, tl.trans(_weight).to(tl.float32)) _weight = tl.load( weight_ptrs, - mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[:, None] < vocab_size), + mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) + & (offs_bn[:, None] < vocab_size), other=0.0, ) _d_hidden = tl.dot(d_logits, _weight.to(tl.float32)) tl.atomic_add( d_hidden_ptrs, _d_hidden, - mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens), + mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) + & (offs_am[:, None] < num_tokens), ) hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k @@ -844,7 +988,12 @@ def efficient_entropy_backward_kernel_general_mainloop_MN( @triton.autotune( configs=[ triton.Config( - {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16}, + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 16, + }, num_stages=3, num_warps=8, ), @@ -897,12 +1046,22 @@ def efficient_entropy_backward_kernel_d_hidden( offs_k = tl.arange(0, BLOCK_SIZE_K) result_offs_k = 
pid_k * BLOCK_SIZE_K + offs_k - maximum = tl.load(maximum_ptr + offs_m * stride_maximum, mask=offs_m < num_tokens, other=0.0) - accu = tl.load(accu_ptr + offs_m * stride_accu, mask=offs_m < num_tokens, other=1e-6) + maximum = tl.load( + maximum_ptr + offs_m * stride_maximum, mask=offs_m < num_tokens, other=0.0 + ) + accu = tl.load( + accu_ptr + offs_m * stride_accu, mask=offs_m < num_tokens, other=1e-6 + ) accu_rcp = tl.fdiv(1.0, accu) - d_entropy = tl.load(d_entropy_ptr + offs_m * stride_d_entropy, mask=offs_m < num_tokens, other=0.0) + d_entropy = tl.load( + d_entropy_ptr + offs_m * stride_d_entropy, mask=offs_m < num_tokens, other=0.0 + ) if reduction == 0: - d_logprobs = tl.load(d_logprobs_ptr + offs_m * stride_d_logprobs, mask=offs_m < num_tokens, other=0.0) + d_logprobs = tl.load( + d_logprobs_ptr + offs_m * stride_d_logprobs, + mask=offs_m < num_tokens, + other=0.0, + ) elif reduction == 1: d_logprobs = tl.load(d_logprobs_ptr) d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) @@ -911,28 +1070,38 @@ def efficient_entropy_backward_kernel_d_hidden( d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) d_logprobs = -1 * d_logprobs - entropy_b = tl.load(entropy_b_ptr + offs_m * stride_entropy_b, mask=offs_m < num_tokens, other=0.0) - labels = tl.load(labels_ptr + offs_m * stride_labels, mask=offs_m < num_tokens, other=0) + entropy_b = tl.load( + entropy_b_ptr + offs_m * stride_entropy_b, mask=offs_m < num_tokens, other=0.0 + ) + labels = tl.load( + labels_ptr + offs_m * stride_labels, mask=offs_m < num_tokens, other=0 + ) # iterate over vocab_size d_hidden = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32) for n in range(0, tl.cdiv(vocab_size, BLOCK_SIZE_N)): offs_n = n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - hidden_ptrs = hidden_ptr + (offs_m[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k) - weight_ptrs = weight_ptr + (offs_n[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k) + hidden_ptrs = hidden_ptr + ( + 
offs_m[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k + ) + weight_ptrs = weight_ptr + ( + offs_n[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k + ) # iterate over hidden_size to get logits logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)): _hidden = tl.load( hidden_ptrs, - mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_m[:, None] < num_tokens), + mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) + & (offs_m[:, None] < num_tokens), other=0.0, ) _weight = tl.load( weight_ptrs, - mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_n[:, None] < vocab_size), + mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) + & (offs_n[:, None] < vocab_size), other=0.0, ) @@ -948,21 +1117,32 @@ def efficient_entropy_backward_kernel_d_hidden( mask = (offs_n + rank * vocab_size)[None, :] == labels[:, None] d_logits = d_logprobs[:, None] * (exp_logits * accu_rcp[:, None] - mask) - d_logits += d_entropy[:, None] * (-exp_logits * accu_rcp[:, None]) * (logits - entropy_b[:, None]) + d_logits += ( + d_entropy[:, None] + * (-exp_logits * accu_rcp[:, None]) + * (logits - entropy_b[:, None]) + ) # scale d_logits d_logits *= rcp_temperature # calculate d_hidden - weight_ptrs = weight_ptr + (offs_n[:, None] * stride_weight_n + result_offs_k[None, :] * stride_weight_k) + weight_ptrs = weight_ptr + ( + offs_n[:, None] * stride_weight_n + result_offs_k[None, :] * stride_weight_k + ) _weight = tl.load( - weight_ptrs, mask=(result_offs_k[None, :] < hidden_size) & (offs_n[:, None] < vocab_size), other=0.0 + weight_ptrs, + mask=(result_offs_k[None, :] < hidden_size) + & (offs_n[:, None] < vocab_size), + other=0.0, ) d_hidden = tl.dot(d_logits.to(weight_ptr.dtype.element_ty), _weight, d_hidden) # write back tl.store( - d_hidden_ptr + offs_m[:, None] * stride_d_hidden_m + result_offs_k[None, :] * stride_d_hidden_k, + d_hidden_ptr + + offs_m[:, None] * 
stride_d_hidden_m + + result_offs_k[None, :] * stride_d_hidden_k, d_hidden, mask=(offs_m[:, None] < num_tokens) & (result_offs_k[None, :] < hidden_size), ) @@ -971,7 +1151,12 @@ def efficient_entropy_backward_kernel_d_hidden( @triton.autotune( configs=[ triton.Config( - {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16}, + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 16, + }, num_stages=3, num_warps=8, ), @@ -1025,12 +1210,24 @@ def efficient_entropy_backward_kernel_d_weight( for m in range(0, tl.cdiv(num_tokens, BLOCK_SIZE_M)): offs_m = m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - maximum = tl.load(maximum_ptr + offs_m * stride_maximum, mask=offs_m < num_tokens, other=0.0) - accu = tl.load(accu_ptr + offs_m * stride_accu, mask=offs_m < num_tokens, other=1e-6) + maximum = tl.load( + maximum_ptr + offs_m * stride_maximum, mask=offs_m < num_tokens, other=0.0 + ) + accu = tl.load( + accu_ptr + offs_m * stride_accu, mask=offs_m < num_tokens, other=1e-6 + ) accu_rcp = tl.fdiv(1.0, accu) - d_entropy = tl.load(d_entropy_ptr + offs_m * stride_d_entropy, mask=offs_m < num_tokens, other=0.0) + d_entropy = tl.load( + d_entropy_ptr + offs_m * stride_d_entropy, + mask=offs_m < num_tokens, + other=0.0, + ) if reduction == 0: - d_logprobs = tl.load(d_logprobs_ptr + offs_m * stride_d_logprobs, mask=offs_m < num_tokens, other=0.0) + d_logprobs = tl.load( + d_logprobs_ptr + offs_m * stride_d_logprobs, + mask=offs_m < num_tokens, + other=0.0, + ) elif reduction == 1: d_logprobs = tl.load(d_logprobs_ptr) d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) @@ -1039,22 +1236,34 @@ def efficient_entropy_backward_kernel_d_weight( d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) d_logprobs = -1 * d_logprobs - entropy_b = tl.load(entropy_b_ptr + offs_m * stride_entropy_b, mask=offs_m < num_tokens, other=0.0) - labels = tl.load(labels_ptr + offs_m * stride_labels, mask=offs_m < num_tokens, other=0) + 
entropy_b = tl.load( + entropy_b_ptr + offs_m * stride_entropy_b, + mask=offs_m < num_tokens, + other=0.0, + ) + labels = tl.load( + labels_ptr + offs_m * stride_labels, mask=offs_m < num_tokens, other=0 + ) - hidden_ptrs = hidden_ptr + (offs_m[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k) - weight_ptrs = weight_ptr + (offs_n[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k) + hidden_ptrs = hidden_ptr + ( + offs_m[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k + ) + weight_ptrs = weight_ptr + ( + offs_n[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k + ) logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)): _hidden = tl.load( hidden_ptrs, - mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_m[:, None] < num_tokens), + mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) + & (offs_m[:, None] < num_tokens), other=0.0, ) _weight = tl.load( weight_ptrs, - mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_n[:, None] < vocab_size), + mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) + & (offs_n[:, None] < vocab_size), other=0.0, ) @@ -1069,19 +1278,32 @@ def efficient_entropy_backward_kernel_d_weight( mask = (offs_n + rank * vocab_size)[None, :] == labels[:, None] d_logits = d_logprobs[:, None] * (exp_logits * accu_rcp[:, None] - mask) - d_logits += d_entropy[:, None] * (-exp_logits * accu_rcp[:, None]) * (logits - entropy_b[:, None]) + d_logits += ( + d_entropy[:, None] + * (-exp_logits * accu_rcp[:, None]) + * (logits - entropy_b[:, None]) + ) d_logits *= rcp_temperature - hidden_ptrs = hidden_ptr + (offs_m[:, None] * stride_hidden_m + result_offs_k[None, :] * stride_hidden_k) + hidden_ptrs = hidden_ptr + ( + offs_m[:, None] * stride_hidden_m + result_offs_k[None, :] * stride_hidden_k + ) _hidden = tl.load( - hidden_ptrs, mask=(result_offs_k[None, :] < hidden_size) & (offs_m[:, None] < num_tokens), 
other=0.0 + hidden_ptrs, + mask=(result_offs_k[None, :] < hidden_size) + & (offs_m[:, None] < num_tokens), + other=0.0, + ) + d_weight = tl.dot( + d_logits.to(d_weight_ptr.dtype.element_ty).trans(), _hidden, d_weight ) - d_weight = tl.dot(d_logits.to(d_weight_ptr.dtype.element_ty).trans(), _hidden, d_weight) # write back tl.store( - d_weight_ptr + offs_n[:, None] * stride_d_weight_n + result_offs_k[None, :] * stride_d_weight_k, + d_weight_ptr + + offs_n[:, None] * stride_d_weight_n + + result_offs_k[None, :] * stride_d_weight_k, d_weight, mask=(offs_n[:, None] < vocab_size) & (result_offs_k[None, :] < hidden_size), ) @@ -1091,7 +1313,12 @@ def efficient_entropy_backward_kernel_d_weight( @triton.autotune( configs=[ triton.Config( - {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16}, + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 16, + }, num_stages=3, num_warps=8, ), @@ -1158,7 +1385,9 @@ def efficient_entropy_backward_kernel_general_d_logits( maximum_ptrs = maximum_ptr + offs_am * stride_maximum maximum = tl.load(maximum_ptrs, mask=offs_am < num_tokens, other=0.0) accu_ptrs = accu_ptr + offs_am * stride_accu - accu = tl.load(accu_ptrs, mask=offs_am < num_tokens, other=1e-6) # epsilon to avoid division by zero + accu = tl.load( + accu_ptrs, mask=offs_am < num_tokens, other=1e-6 + ) # epsilon to avoid division by zero accu_rcp = tl.fdiv(1.0, accu) d_entropy_ptrs = d_entropy_ptr + offs_am * stride_d_entropy @@ -1177,9 +1406,13 @@ def efficient_entropy_backward_kernel_general_d_logits( entropy_b_ptrs = entropy_b_ptr + offs_am * stride_entropy_b entropy_b = tl.load(entropy_b_ptrs, mask=offs_am < num_tokens, other=0.0) - hidden_ptrs = hidden_ptr + (offs_am[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k) + hidden_ptrs = hidden_ptr + ( + offs_am[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k + ) # weight_ptrs = weight_ptr + (offs_k[:, None] * stride_weight_k + 
offs_bn[None, :] * stride_weight_n) - weight_ptrs = weight_ptr + (offs_bn[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k) + weight_ptrs = weight_ptr + ( + offs_bn[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k + ) labels_ptrs = labels_ptr + offs_am * stride_labels labels = tl.load(labels_ptrs, mask=offs_am < num_tokens, other=0) @@ -1187,7 +1420,8 @@ def efficient_entropy_backward_kernel_general_d_logits( for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)): _hidden = tl.load( hidden_ptrs, - mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens), + mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) + & (offs_am[:, None] < num_tokens), other=0.0, ) # _weight = tl.load(weight_ptrs, @@ -1195,7 +1429,8 @@ def efficient_entropy_backward_kernel_general_d_logits( # other=0.0) _weight = tl.load( weight_ptrs, - mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[:, None] < vocab_size), + mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) + & (offs_bn[:, None] < vocab_size), other=0.0, ) @@ -1213,13 +1448,21 @@ def efficient_entropy_backward_kernel_general_d_logits( mask = (offs_bn + rank * vocab_size)[None, :] == labels[:, None] d_logits = d_logprobs[:, None] * (exp_logits * accu_rcp[:, None] - mask) - d_logits += d_entropy[:, None] * (-exp_logits * accu_rcp[:, None]) * (logits - entropy_b[:, None]) + d_logits += ( + d_entropy[:, None] + * (-exp_logits * accu_rcp[:, None]) + * (logits - entropy_b[:, None]) + ) # scale d_logits by temperature d_logits *= rcp_temperature # store d_logits - d_logits_ptrs = d_logits_ptr + offs_am[:, None] * stride_d_logits_m + offs_bn[None, :] * stride_d_logits_n + d_logits_ptrs = ( + d_logits_ptr + + offs_am[:, None] * stride_d_logits_m + + offs_bn[None, :] * stride_d_logits_n + ) tl.store( d_logits_ptrs, d_logits, # will be implicitly converted to d_logits_ptrs.dtype.element_ty @@ -1230,7 +1473,12 @@ def 
efficient_entropy_backward_kernel_general_d_logits( @triton.autotune( configs=[ triton.Config( - {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16}, + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 16, + }, num_stages=3, num_warps=8, ), @@ -1284,15 +1532,27 @@ def efficient_entropy_backward_kernel_general_d_logits_split_N( pid_n = (pid % num_pid_in_group) // group_size_m offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_bn = split_idx * vocab_per_split + pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_bn = ( + split_idx * vocab_per_split + pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + ) offs_k = tl.arange(0, BLOCK_SIZE_K) - maximum = tl.load(maximum_ptr + offs_am * stride_maximum, mask=offs_am < num_tokens, other=0.0) - accu = tl.load(accu_ptr + offs_am * stride_accu, mask=offs_am < num_tokens, other=1e-6) + maximum = tl.load( + maximum_ptr + offs_am * stride_maximum, mask=offs_am < num_tokens, other=0.0 + ) + accu = tl.load( + accu_ptr + offs_am * stride_accu, mask=offs_am < num_tokens, other=1e-6 + ) accu_rcp = tl.fdiv(1.0, accu) - d_entropy = tl.load(d_entropy_ptr + offs_am * stride_d_entropy, mask=offs_am < num_tokens, other=0.0) + d_entropy = tl.load( + d_entropy_ptr + offs_am * stride_d_entropy, mask=offs_am < num_tokens, other=0.0 + ) if reduction == 0: - d_logprobs = tl.load(d_logprobs_ptr + offs_am * stride_d_logprobs, mask=offs_am < num_tokens, other=0.0) + d_logprobs = tl.load( + d_logprobs_ptr + offs_am * stride_d_logprobs, + mask=offs_am < num_tokens, + other=0.0, + ) elif reduction == 1: d_logprobs = tl.load(d_logprobs_ptr) d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) @@ -1300,23 +1560,33 @@ def efficient_entropy_backward_kernel_general_d_logits_split_N( d_logprobs = tl.fdiv(tl.load(d_logprobs_ptr), num_tokens.to(tl.float32)) d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) d_logprobs = -1 * d_logprobs - entropy_b = 
tl.load(entropy_b_ptr + offs_am * stride_entropy_b, mask=offs_am < num_tokens, other=0.0) - labels = tl.load(labels_ptr + offs_am * stride_labels, mask=offs_am < num_tokens, other=0) + entropy_b = tl.load( + entropy_b_ptr + offs_am * stride_entropy_b, mask=offs_am < num_tokens, other=0.0 + ) + labels = tl.load( + labels_ptr + offs_am * stride_labels, mask=offs_am < num_tokens, other=0 + ) - hidden_ptrs = hidden_ptr + (offs_am[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k) - weight_ptrs = weight_ptr + (offs_bn[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k) + hidden_ptrs = hidden_ptr + ( + offs_am[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k + ) + weight_ptrs = weight_ptr + ( + offs_bn[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k + ) vocab_right_bound = min((split_idx + 1) * vocab_per_split, vocab_size) logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)): _hidden = tl.load( hidden_ptrs, - mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens), + mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) + & (offs_am[:, None] < num_tokens), other=0.0, ) _weight = tl.load( weight_ptrs, - mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[:, None] < vocab_right_bound), + mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) + & (offs_bn[:, None] < vocab_right_bound), other=0.0, ) logits = tl.dot(_hidden, _weight.trans(), logits) @@ -1329,7 +1599,11 @@ def efficient_entropy_backward_kernel_general_d_logits_split_N( mask = (offs_bn + rank * vocab_size)[None, :] == labels[:, None] d_logits = d_logprobs[:, None] * (exp_logits * accu_rcp[:, None] - mask) - d_logits += d_entropy[:, None] * (-exp_logits * accu_rcp[:, None]) * (logits - entropy_b[:, None]) + d_logits += ( + d_entropy[:, None] + * (-exp_logits * accu_rcp[:, None]) + * (logits - entropy_b[:, None]) + ) d_logits *= 
rcp_temperature @@ -1338,7 +1612,11 @@ def efficient_entropy_backward_kernel_general_d_logits_split_N( mask = (offs_am[:, None] < num_tokens) & (result_offs_n[None, :] < vocab_per_split) tl.store( - d_logits_ptr + offs_am[:, None] * stride_d_logits_m + result_offs_n[None, :] * stride_d_logits_n, d_logits, mask + d_logits_ptr + + offs_am[:, None] * stride_d_logits_m + + result_offs_n[None, :] * stride_d_logits_n, + d_logits, + mask, ) @@ -1366,7 +1644,9 @@ def efficient_entropy_backward( assert hidden.shape[0] == labels.shape[0] and hidden.shape[1] == weight.shape[1] _rank = 0 if dist_process_group is None else dist.get_rank(dist_process_group) - _world_size = 1 if dist_process_group is None else dist.get_world_size(dist_process_group) + _world_size = ( + 1 if dist_process_group is None else dist.get_world_size(dist_process_group) + ) num_tokens, hidden_size = hidden.shape num_tokens = labels.shape[0] @@ -1409,7 +1689,10 @@ def efficient_entropy_backward( if _config._backward == BackwardEnum._Total_Fuse_MN: # --- Triton doesn't materialize d_logits at all. Split tiles at the perspective of d_logits. 
def mainloop_grid(meta): - return (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]) * triton.cdiv(vocab_size, meta["BLOCK_SIZE_N"]),) + return ( + triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]) + * triton.cdiv(vocab_size, meta["BLOCK_SIZE_N"]), + ) efficient_entropy_backward_kernel_general_mainloop_MN[mainloop_grid]( num_tokens, @@ -1445,13 +1728,18 @@ def mainloop_grid(meta): ) elif _config._backward == BackwardEnum._Total_Separate: - _d_logits = torch.empty((num_tokens, vocab_size), device=hidden.device, dtype=hidden.dtype).contiguous() + _d_logits = torch.empty( + (num_tokens, vocab_size), device=hidden.device, dtype=hidden.dtype + ).contiguous() assert _d_logits.is_contiguous() if _config._use_triton: def d_logits_grid(meta): - return (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]) * triton.cdiv(vocab_size, meta["BLOCK_SIZE_N"]),) + return ( + triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]) + * triton.cdiv(vocab_size, meta["BLOCK_SIZE_N"]), + ) efficient_entropy_backward_kernel_general_d_logits[d_logits_grid]( num_tokens, @@ -1492,11 +1780,16 @@ def d_logits_grid(meta): vocab_per_split = 9504 num_splits = (vocab_size + vocab_per_split - 1) // vocab_per_split - _d_logits = torch.empty((num_tokens, vocab_per_split), device=hidden.device, dtype=hidden.dtype).contiguous() + _d_logits = torch.empty( + (num_tokens, vocab_per_split), device=hidden.device, dtype=hidden.dtype + ).contiguous() assert _d_logits.is_contiguous() def d_logits_grid(meta): - return (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]) * triton.cdiv(vocab_per_split, meta["BLOCK_SIZE_N"]),) + return ( + triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]) + * triton.cdiv(vocab_per_split, meta["BLOCK_SIZE_N"]), + ) for split_idx in range(num_splits): efficient_entropy_backward_kernel_general_d_logits_split_N[d_logits_grid]( @@ -1532,22 +1825,40 @@ def d_logits_grid(meta): ) if split_idx == (num_splits - 1): - vocab_right_bound = min((split_idx + 1) * vocab_per_split, vocab_size) - split_idx * vocab_per_split + 
vocab_right_bound = ( + min((split_idx + 1) * vocab_per_split, vocab_size) + - split_idx * vocab_per_split + ) _d_logits = _d_logits[:, :vocab_right_bound].contiguous() if split_idx == 0: torch.matmul( - _d_logits, weight[split_idx * vocab_per_split : (split_idx + 1) * vocab_per_split, :], out=d_hidden + _d_logits, + weight[ + split_idx * vocab_per_split : (split_idx + 1) * vocab_per_split, + :, + ], + out=d_hidden, ) else: d_hidden += torch.matmul( - _d_logits, weight[split_idx * vocab_per_split : (split_idx + 1) * vocab_per_split, :] + _d_logits, + weight[ + split_idx * vocab_per_split : (split_idx + 1) * vocab_per_split, + :, + ], ) torch.matmul( - _d_logits.T, hidden, out=d_weight[split_idx * vocab_per_split : (split_idx + 1) * vocab_per_split, :] + _d_logits.T, + hidden, + out=d_weight[ + split_idx * vocab_per_split : (split_idx + 1) * vocab_per_split, : + ], ) elif _config._backward == BackwardEnum._Split_Dlogits_M: - raise NotImplementedError("BackwardEnum._Split_Dlogits_M is not implemented yet") + raise NotImplementedError( + "BackwardEnum._Split_Dlogits_M is not implemented yet" + ) return d_hidden, d_weight diff --git a/Agent0/executor_train/verl/verl/utils/kernel/linear_cross_entropy.py b/Agent0/executor_train/verl/verl/utils/kernel/linear_cross_entropy.py index 733a815..a011d95 100644 --- a/Agent0/executor_train/verl/verl/utils/kernel/linear_cross_entropy.py +++ b/Agent0/executor_train/verl/verl/utils/kernel/linear_cross_entropy.py @@ -63,22 +63,32 @@ def forward( typing.List[torch.Tensor]: _description_ """ - assert isinstance(temperature, float), f"temperature must be a float, but got {type(temperature)}" - assert isinstance(reduction, str), f"reduction must be a str, but got {type(reduction)}" + assert isinstance( + temperature, float + ), f"temperature must be a float, but got {type(temperature)}" + assert isinstance( + reduction, str + ), f"reduction must be a str, but got {type(reduction)}" with 
torch.cuda.nvtx.range("LinearCrossEntropy-forward"): REDUCTION = kernels.get_entropy_reduction_enum_number(reduction.lower()) original_hidden_shape = hidden.shape if len(hidden.shape) != 2: - hidden = hidden.view(-1, hidden.shape[-1]) # (batch_size * num_tokens, hidden_size) + hidden = hidden.view( + -1, hidden.shape[-1] + ) # (batch_size * num_tokens, hidden_size) if len(labels.shape) != 1: labels = labels.view(-1) - logprobs, entropy, _maximum, _accumulate, _entropy_b = kernels.efficient_entropy_forward( - hidden, weight, labels, REDUCTION, temperature, dist_process_group + logprobs, entropy, _maximum, _accumulate, _entropy_b = ( + kernels.efficient_entropy_forward( + hidden, weight, labels, REDUCTION, temperature, dist_process_group + ) ) - ctx.save_for_backward(hidden, weight, labels, _maximum, _accumulate, _entropy_b) + ctx.save_for_backward( + hidden, weight, labels, _maximum, _accumulate, _entropy_b + ) ctx.original_hidden_shape = original_hidden_shape ctx.REDUCTION = REDUCTION ctx.dist_process_group = dist_process_group @@ -87,9 +97,13 @@ def forward( return logprobs, entropy @staticmethod - def backward(ctx, dlogprobs: torch.Tensor, dentropy: torch.Tensor) -> list[torch.Tensor]: + def backward( + ctx, dlogprobs: torch.Tensor, dentropy: torch.Tensor + ) -> list[torch.Tensor]: with torch.cuda.nvtx.range("LinearCrossEntropy-backward"): - (hidden, weight, labels, _maximum, _accumulate, _entropy_b) = ctx.saved_tensors + (hidden, weight, labels, _maximum, _accumulate, _entropy_b) = ( + ctx.saved_tensors + ) REDUCTION = ctx.REDUCTION dist_process_group = ctx.dist_process_group should_return_fp32_grad = ctx.should_return_fp32_grad diff --git a/Agent0/executor_train/verl/verl/utils/logger/aggregate_logger.py b/Agent0/executor_train/verl/verl/utils/logger/aggregate_logger.py index d29698a..d9fb5b9 100644 --- a/Agent0/executor_train/verl/verl/utils/logger/aggregate_logger.py +++ b/Agent0/executor_train/verl/verl/utils/logger/aggregate_logger.py @@ -64,7 +64,12 @@ 
class DecoratorLoggerBase: """ def __init__( - self, role: str, logger: logging.Logger = None, level=logging.DEBUG, rank: int = 0, log_only_rank_0: bool = True + self, + role: str, + logger: logging.Logger = None, + level=logging.DEBUG, + rank: int = 0, + log_only_rank_0: bool = True, ): self.role = role self.logger = logger @@ -109,7 +114,9 @@ def print_with_rank(message: str, rank: int = 0, log_only_rank_0: bool = False): print(f"[Rank {rank}] {message}", flush=True) -def print_with_rank_and_timer(message: str, rank: int = 0, log_only_rank_0: bool = False): +def print_with_rank_and_timer( + message: str, rank: int = 0, log_only_rank_0: bool = False +): """_summary_ Print a message with rank information and a timestamp. This function prints the message only if `log_only_rank_0` is False or if the rank is 0. @@ -125,7 +132,13 @@ def print_with_rank_and_timer(message: str, rank: int = 0, log_only_rank_0: bool print(message, flush=True) -def log_with_rank(message: str, rank, logger: logging.Logger, level=logging.INFO, log_only_rank_0: bool = False): +def log_with_rank( + message: str, + rank, + logger: logging.Logger, + level=logging.INFO, + log_only_rank_0: bool = False, +): """_summary_ Log a message with rank information using a logger. This function logs the message only if `log_only_rank_0` is False or if the rank is 0. 
diff --git a/Agent0/executor_train/verl/verl/utils/megatron/dist_checkpointing.py b/Agent0/executor_train/verl/verl/utils/megatron/dist_checkpointing.py index d95752a..146324c 100644 --- a/Agent0/executor_train/verl/verl/utils/megatron/dist_checkpointing.py +++ b/Agent0/executor_train/verl/verl/utils/megatron/dist_checkpointing.py @@ -51,6 +51,8 @@ def load_dist_checkpointing(sharded_state_dict, ckpt_dir): ) # Load model sharded state dicts - state_dict = dist_checkpointing.load(sharded_state_dict, ckpt_dir, sharded_strategy=load_strategy) + state_dict = dist_checkpointing.load( + sharded_state_dict, ckpt_dir, sharded_strategy=load_strategy + ) return state_dict diff --git a/Agent0/executor_train/verl/verl/utils/megatron/memory.py b/Agent0/executor_train/verl/verl/utils/megatron/memory.py index bc62d42..08f891b 100644 --- a/Agent0/executor_train/verl/verl/utils/megatron/memory.py +++ b/Agent0/executor_train/verl/verl/utils/megatron/memory.py @@ -22,7 +22,12 @@ def __init__(self, numel, numel_padded, dtype): self.numel = numel self.numel_padded = numel_padded self.dtype = dtype - self.data = torch.zeros(self.numel_padded, dtype=self.dtype, device=get_device_id(), requires_grad=False) + self.data = torch.zeros( + self.numel_padded, + dtype=self.dtype, + device=get_device_id(), + requires_grad=False, + ) def zero(self): """Reset the buffer to zero.""" diff --git a/Agent0/executor_train/verl/verl/utils/megatron/optimizer.py b/Agent0/executor_train/verl/verl/utils/megatron/optimizer.py index 100c161..0caad7a 100644 --- a/Agent0/executor_train/verl/verl/utils/megatron/optimizer.py +++ b/Agent0/executor_train/verl/verl/utils/megatron/optimizer.py @@ -14,7 +14,9 @@ # limitations under the License. 
from megatron.core.optimizer import OptimizerConfig -from megatron.core.optimizer import get_megatron_optimizer as get_megatron_optimizer_native +from megatron.core.optimizer import ( + get_megatron_optimizer as get_megatron_optimizer_native, +) from megatron.core.optimizer_param_scheduler import OptimizerParamScheduler @@ -50,7 +52,9 @@ def get_megatron_optimizer_param_scheduler( if config.get("lr_warmup_steps_ratio", None) is not None and ( config.get("lr_warmup_steps", None) is None or config.lr_warmup_steps <= 0 ): - config.lr_warmup_steps = int(config.lr_warmup_steps_ratio * config.lr_decay_steps) + config.lr_warmup_steps = int( + config.lr_warmup_steps_ratio * config.lr_decay_steps + ) opt_param_scheduler = OptimizerParamScheduler( optimizer, diff --git a/Agent0/executor_train/verl/verl/utils/megatron/pipeline_parallel.py b/Agent0/executor_train/verl/verl/utils/megatron/pipeline_parallel.py index 50ba697..b33fcc9 100644 --- a/Agent0/executor_train/verl/verl/utils/megatron/pipeline_parallel.py +++ b/Agent0/executor_train/verl/verl/utils/megatron/pipeline_parallel.py @@ -27,14 +27,17 @@ def compute_transformers_input_shapes(batches, meta_info): for model_inputs in batches: input_ids = model_inputs["input_ids"] attention_mask = model_inputs["attention_mask"] - input_ids_rmpad = unpad_input(input_ids.unsqueeze(dim=-1), attention_mask)[0] # (total_nnz, 1) + input_ids_rmpad = unpad_input(input_ids.unsqueeze(dim=-1), attention_mask)[ + 0 + ] # (total_nnz, 1) if meta_info["sequence_parallel"]: input_ids_rmpad = pad_to_sequence_parallel(input_ids_rmpad) # compute shapes for model_inputs input_shapes.append( torch.Size( [ - input_ids_rmpad.shape[0] // mpu.get_tensor_model_parallel_world_size(), + input_ids_rmpad.shape[0] + // mpu.get_tensor_model_parallel_world_size(), 1, meta_info["hidden_size"], ] @@ -42,7 +45,9 @@ def compute_transformers_input_shapes(batches, meta_info): ) else: # compute shapes for model_inputs - 
input_shapes.append(torch.Size([input_ids_rmpad.shape[0], 1, meta_info["hidden_size"]])) + input_shapes.append( + torch.Size([input_ids_rmpad.shape[0], 1, meta_info["hidden_size"]]) + ) return input_shapes diff --git a/Agent0/executor_train/verl/verl/utils/megatron/sequence_parallel.py b/Agent0/executor_train/verl/verl/utils/megatron/sequence_parallel.py index 52fda9b..3115f45 100644 --- a/Agent0/executor_train/verl/verl/utils/megatron/sequence_parallel.py +++ b/Agent0/executor_train/verl/verl/utils/megatron/sequence_parallel.py @@ -39,7 +39,11 @@ def pad_to_sequence_parallel(unpad_tokens: torch.Tensor): total_nnz = unpad_tokens.shape[0] sp_world_size = mpu.get_tensor_model_parallel_world_size() - pad_size = 0 if total_nnz % sp_world_size == 0 else sp_world_size - total_nnz % sp_world_size + pad_size = ( + 0 + if total_nnz % sp_world_size == 0 + else sp_world_size - total_nnz % sp_world_size + ) if pad_size > 0: if unpad_tokens.ndim == 1: @@ -47,6 +51,8 @@ def pad_to_sequence_parallel(unpad_tokens: torch.Tensor): elif unpad_tokens.ndim == 2: unpad_tokens = F.pad(unpad_tokens, (0, 0, 0, pad_size)) else: - raise NotImplementedError(f"Padding dim {unpad_tokens.ndim()} is not supported") + raise NotImplementedError( + f"Padding dim {unpad_tokens.ndim()} is not supported" + ) return unpad_tokens diff --git a/Agent0/executor_train/verl/verl/utils/megatron/tensor_parallel.py b/Agent0/executor_train/verl/verl/utils/megatron/tensor_parallel.py index d4a99b9..3295c7a 100644 --- a/Agent0/executor_train/verl/verl/utils/megatron/tensor_parallel.py +++ b/Agent0/executor_train/verl/verl/utils/megatron/tensor_parallel.py @@ -114,21 +114,35 @@ def mul_reduce(a, b): return (a * b).sum(dim=-1, keepdim=True) logits_max = vocab_parallel_logits.max(dim=-1, keepdim=True).values - dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=mpu.get_tensor_model_parallel_group()) + dist.all_reduce( + logits_max, + op=dist.ReduceOp.MAX, + group=mpu.get_tensor_model_parallel_group(), + ) 
normalized_vocab_parallel_logits = vocab_parallel_logits - logits_max normalized_exp_logits = normalized_vocab_parallel_logits.exp_() normalized_sum_exp_logits = normalized_exp_logits.sum(dim=-1, keepdim=True) - dist.all_reduce(normalized_sum_exp_logits, group=mpu.get_tensor_model_parallel_group()) + dist.all_reduce( + normalized_sum_exp_logits, group=mpu.get_tensor_model_parallel_group() + ) softmax_logits = normalized_exp_logits.div_(normalized_sum_exp_logits) sum_softmax_times_logits = mul_reduce(softmax_logits, vocab_parallel_logits) - dist.all_reduce(sum_softmax_times_logits, group=mpu.get_tensor_model_parallel_group()) - entropy = logits_max + normalized_sum_exp_logits.log() - sum_softmax_times_logits - ctx.save_for_backward(vocab_parallel_logits, softmax_logits, sum_softmax_times_logits) + dist.all_reduce( + sum_softmax_times_logits, group=mpu.get_tensor_model_parallel_group() + ) + entropy = ( + logits_max + normalized_sum_exp_logits.log() - sum_softmax_times_logits + ) + ctx.save_for_backward( + vocab_parallel_logits, softmax_logits, sum_softmax_times_logits + ) return entropy.squeeze(dim=-1) @staticmethod def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor: - vocab_parallel_logits, softmax_logits, sum_softmax_times_logits = ctx.saved_tensors + vocab_parallel_logits, softmax_logits, sum_softmax_times_logits = ( + ctx.saved_tensors + ) # reuse softmax_logits as grad vocab_parallel_logits.sub_(sum_softmax_times_logits) softmax_logits.mul_(vocab_parallel_logits) @@ -155,10 +169,14 @@ def vocab_parallel_log_probs_from_logits(logits, labels): """TODO(zhangchi.usc1992): We may change the implementation later""" from megatron.core import tensor_parallel - return -tensor_parallel.vocab_parallel_cross_entropy(vocab_parallel_logits=logits, target=labels) + return -tensor_parallel.vocab_parallel_cross_entropy( + vocab_parallel_logits=logits, target=labels + ) -def vocab_parallel_log_probs_from_logits_response_rmpad(input_ids, attention_mask, logits_rmpad, 
response_length): +def vocab_parallel_log_probs_from_logits_response_rmpad( + input_ids, attention_mask, logits_rmpad, response_length +): """Similar to log_probs_from_logits_response_rmpad, but the logits_rmpad is now spliited across tensor parallel region. This will further reduce the peak memory usage during training @@ -173,14 +191,21 @@ def vocab_parallel_log_probs_from_logits_response_rmpad(input_ids, attention_mas from flash_attn.bert_padding import pad_input, unpad_input batch_size, seqlen = input_ids.shape - input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask=attention_mask) + input_ids_rmpad, indices, *_ = unpad_input( + input_ids.unsqueeze(-1), attention_mask=attention_mask + ) input_ids_rmpad = input_ids_rmpad.squeeze(-1) input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=0) full_log_probs_rmpad = vocab_parallel_log_probs_from_logits( logits=logits_rmpad, labels=input_ids_rmpad_rolled ) # (total_nnz,) full_output = pad_input( - hidden_states=full_log_probs_rmpad.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen + hidden_states=full_log_probs_rmpad.unsqueeze(-1), + indices=indices, + batch=batch_size, + seqlen=seqlen, ) - output = full_output.squeeze(-1)[:, -response_length - 1 : -1] # [batch_size, response_length] + output = full_output.squeeze(-1)[ + :, -response_length - 1 : -1 + ] # [batch_size, response_length] return output diff --git a/Agent0/executor_train/verl/verl/utils/megatron_utils.py b/Agent0/executor_train/verl/verl/utils/megatron_utils.py index 2fc7437..3b7b01a 100644 --- a/Agent0/executor_train/verl/verl/utils/megatron_utils.py +++ b/Agent0/executor_train/verl/verl/utils/megatron_utils.py @@ -57,16 +57,18 @@ def get_model( mpu.get_pipeline_model_parallel_world_size() > 1 and mpu.get_virtual_pipeline_model_parallel_world_size() is not None ): - assert model_type != ModelType.encoder_and_decoder, ( - "Interleaved schedule not supported for model with both encoder and decoder" - 
) + assert ( + model_type != ModelType.encoder_and_decoder + ), "Interleaved schedule not supported for model with both encoder and decoder" model = [] for i in range(mpu.get_virtual_pipeline_model_parallel_world_size()): mpu.set_virtual_pipeline_model_parallel_rank(i) # Set pre_process and post_process only after virtual rank is set. pre_process = mpu.is_pipeline_first_stage() post_process = mpu.is_pipeline_last_stage() - this_model = model_provider_func(pre_process=pre_process, post_process=post_process) + this_model = model_provider_func( + pre_process=pre_process, post_process=post_process + ) this_model.model_type = model_type model.append(this_model) mpu.set_virtual_pipeline_model_parallel_rank(0) @@ -77,9 +79,9 @@ def get_model( add_decoder = True if model_type == ModelType.encoder_and_decoder: if mpu.get_pipeline_model_parallel_world_size() > 1: - assert mpu.get_pipeline_model_parallel_split_rank() is not None, ( - "Split rank needs to be specified for model with both encoder and decoder" - ) + assert ( + mpu.get_pipeline_model_parallel_split_rank() is not None + ), "Split rank needs to be specified for model with both encoder and decoder" rank = mpu.get_pipeline_model_parallel_rank() split_rank = mpu.get_pipeline_model_parallel_split_rank() world_size = mpu.get_pipeline_model_parallel_world_size() @@ -88,10 +90,15 @@ def get_model( add_encoder = mpu.is_pipeline_stage_before_split() add_decoder = mpu.is_pipeline_stage_after_split() model = model_provider_func( - pre_process=pre_process, post_process=post_process, add_encoder=add_encoder, add_decoder=add_decoder + pre_process=pre_process, + post_process=post_process, + add_encoder=add_encoder, + add_decoder=add_decoder, ) else: - model = model_provider_func(pre_process=pre_process, post_process=post_process) + model = model_provider_func( + pre_process=pre_process, post_process=post_process + ) model.model_type = model_type if not isinstance(model, list): @@ -103,7 +110,9 @@ def get_model( # are set for all 
params so the optimizer can use them. for model_module in model: for param in model_module.parameters(): - tensor_parallel.set_defaults_if_not_set_tensor_model_parallel_attributes(param) + tensor_parallel.set_defaults_if_not_set_tensor_model_parallel_attributes( + param + ) # Print number of parameters. if mpu.get_data_parallel_rank() == 0: @@ -111,7 +120,12 @@ def get_model( " > number of parameters on (tensor, pipeline) model parallel rank ({}, {}): {}".format( mpu.get_tensor_model_parallel_rank(), mpu.get_pipeline_model_parallel_rank(), - sum([sum([p.nelement() for p in model_module.parameters()]) for model_module in model]), + sum( + [ + sum([p.nelement() for p in model_module.parameters()]) + for model_module in model + ] + ), ), flush=True, ) @@ -172,7 +186,11 @@ def convert_config(hf_config: PretrainedConfig, megatron_config) -> TransformerC print(f"megatron config {megatron_config}") dt = PrecisionType.to_dtype(megatron_config.params_dtype) print(f"pipeline_dtype=megatron_config {dt}") - qkv_bias = True if "Qwen2ForCausalLM" in hf_config.architectures else getattr(hf_config, "attention_bias", False) + qkv_bias = ( + True + if "Qwen2ForCausalLM" in hf_config.architectures + else getattr(hf_config, "attention_bias", False) + ) overlap_p2p_comm = ( mpu.get_virtual_pipeline_model_parallel_world_size() is not None and mpu.get_virtual_pipeline_model_parallel_world_size() > 1 @@ -264,16 +282,24 @@ def offload_megatron_model_to_cpu(models): """ for model_chunk in models: if isinstance(model_chunk, DDP): - model_chunk_all_buffers = [model_chunk.buffers, model_chunk.expert_parallel_buffers] + model_chunk_all_buffers = [ + model_chunk.buffers, + model_chunk.expert_parallel_buffers, + ] for buffers in model_chunk_all_buffers: for buffer in buffers: # offload parameters if buffer.param_data.storage().size() > 0: - buffer.param_data.cpu_data = buffer.param_data.data.cpu().pin_memory() + buffer.param_data.cpu_data = ( + buffer.param_data.data.cpu().pin_memory() + ) 
buffer.param_data_size = buffer.param_data.storage().size() buffer.param_data.storage().resize_(0) - assert buffer.param_data_size == buffer.param_data.cpu_data.storage().size() + assert ( + buffer.param_data_size + == buffer.param_data.cpu_data.storage().size() + ) if buffer.grad_data.storage().size() > 0: # if the grad_data size is already zero, we assume that it is already offloaded @@ -293,7 +319,10 @@ def offload_megatron_model_to_cpu(models): def load_megatron_model_to_gpu(models, load_grad=True): for model_chunk in models: if isinstance(model_chunk, DDP): - model_chunk_all_buffers = [model_chunk.buffers, model_chunk.expert_parallel_buffers] + model_chunk_all_buffers = [ + model_chunk.buffers, + model_chunk.expert_parallel_buffers, + ] for buffers in model_chunk_all_buffers: for buffer in buffers: # sometimes, we don't want to load grad for pure inference @@ -304,7 +333,9 @@ def load_megatron_model_to_gpu(models, load_grad=True): if buffer.param_data.storage().size() == 0: buffer.param_data.storage().resize_(buffer.param_data_size) # copy data from cpu to cuda - buffer.param_data.copy_(buffer.param_data.cpu_data, non_blocking=True) + buffer.param_data.copy_( + buffer.param_data.cpu_data, non_blocking=True + ) else: # we need this for ref module device_id = get_device_id() @@ -472,7 +503,9 @@ def convert_qkv_shard(full_tensor, q_name, k_name, v_name): q_shard_list = [] k_shard_list = [] v_shard_list = [] - hidden_size_per_head = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + hidden_size_per_head = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) if config.num_key_value_heads >= tp_size: q_size_tp = hidden_size_per_head * config.num_attention_heads // tp_size @@ -520,7 +553,9 @@ def convert_gate_up_shard(full_tensor, gate_name, up_name): gate_weight_list = [] up_weight_list = [] for i in range(tp_size): - gate_up_weight_tp = full_tensor[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * 
(i + 1)] + gate_up_weight_tp = full_tensor[ + intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1) + ] gate_weight_tp = gate_up_weight_tp[:intermediate_size_tp] up_weight_tp = gate_up_weight_tp[intermediate_size_tp:] gate_weight_list.append(gate_weight_tp) @@ -540,7 +575,9 @@ def convert_gate_up_shard(full_tensor, gate_name, up_name): new_params[f"model.layers.{layer_number}.self_attn.o_proj.weight"] = param elif component == "linear_qkv" and not isinstance(param, list): if param_type == "layer_norm_weight": - new_params[f"model.layers.{layer_number}.input_layernorm.weight"] = param + new_params[f"model.layers.{layer_number}.input_layernorm.weight"] = ( + param + ) else: if convert_qkv_gate_up_by_trunk_concat: convert_qkv_shard( @@ -550,16 +587,26 @@ def convert_gate_up_shard(full_tensor, gate_name, up_name): f"model.layers.{layer_number}.self_attn.v_proj.{param_type}", ) else: - new_params[f"model.layers.{layer_number}.self_attn.qkv_proj.{param_type}"] = param + new_params[ + f"model.layers.{layer_number}.self_attn.qkv_proj.{param_type}" + ] = param elif component == "q_layernorm" or component == "k_layernorm": hf_component = component.replace("layer", "") - new_params[f"model.layers.{layer_number}.self_attn.{hf_component}.weight"] = param + new_params[ + f"model.layers.{layer_number}.self_attn.{hf_component}.weight" + ] = param else: assert isinstance(param, list) and len(param) == 3 assert param_type == "weight" or param_type == "bias" - new_params[f"model.layers.{layer_number}.self_attn.q_proj.{param_type}"] = param[0] - new_params[f"model.layers.{layer_number}.self_attn.k_proj.{param_type}"] = param[1] - new_params[f"model.layers.{layer_number}.self_attn.v_proj.{param_type}"] = param[2] + new_params[f"model.layers.{layer_number}.self_attn.q_proj.{param_type}"] = ( + param[0] + ) + new_params[f"model.layers.{layer_number}.self_attn.k_proj.{param_type}"] = ( + param[1] + ) + new_params[f"model.layers.{layer_number}.self_attn.v_proj.{param_type}"] = 
( + param[2] + ) elif "mlp" in name: splitted_name = name.split(".") layer_number = splitted_name[2] @@ -567,7 +614,9 @@ def convert_gate_up_shard(full_tensor, gate_name, up_name): param_type = splitted_name[5] if component == "linear_fc1" and not isinstance(param, list): if param_type == "layer_norm_weight": - new_params[f"model.layers.{layer_number}.post_attention_layernorm.weight"] = param + new_params[ + f"model.layers.{layer_number}.post_attention_layernorm.weight" + ] = param elif param_type == "weight": if convert_qkv_gate_up_by_trunk_concat: convert_gate_up_shard( @@ -576,7 +625,9 @@ def convert_gate_up_shard(full_tensor, gate_name, up_name): f"model.layers.{layer_number}.mlp.up_proj.weight", ) else: - new_params[f"model.layers.{layer_number}.mlp.gate_up_proj.weight"] = param + new_params[ + f"model.layers.{layer_number}.mlp.gate_up_proj.weight" + ] = param elif component == "linear_fc1" and isinstance(param, list): assert len(param) == 2 assert param_type == "weight" or param_type == "bias" @@ -605,7 +656,9 @@ def broadcast_from_megatron_pp(tensor: torch.Tensor): tensor_spec = None tensor_spec_output = [None] * mpu.get_pipeline_model_parallel_world_size() torch.distributed.all_gather_object( - object_list=tensor_spec_output, obj=tensor_spec, group=mpu.get_pipeline_model_parallel_group() + object_list=tensor_spec_output, + obj=tensor_spec, + group=mpu.get_pipeline_model_parallel_group(), ) # find the src rank target_tensor_spec = None @@ -619,20 +672,30 @@ def broadcast_from_megatron_pp(tensor: torch.Tensor): src_rank = rank assert target_tensor_spec is not None if tensor is None: - tensor = torch.empty(size=target_tensor_spec[0], dtype=target_tensor_spec[1], device=get_device_id()) + tensor = torch.empty( + size=target_tensor_spec[0], + dtype=target_tensor_spec[1], + device=get_device_id(), + ) if target_tensor_spec[2] is not None: tensor.tensor_model_parallel = target_tensor_spec[2] if target_tensor_spec[3] is not None: tensor.partition_dim = 
target_tensor_spec[3] - global_rank = torch.distributed.get_global_rank(group=mpu.get_pipeline_model_parallel_group(), group_rank=src_rank) - torch.distributed.broadcast(tensor=tensor, src=global_rank, group=mpu.get_pipeline_model_parallel_group()) + global_rank = torch.distributed.get_global_rank( + group=mpu.get_pipeline_model_parallel_group(), group_rank=src_rank + ) + torch.distributed.broadcast( + tensor=tensor, src=global_rank, group=mpu.get_pipeline_model_parallel_group() + ) return tensor def broadcast_str_from_megatron_pp(obj: Any): obj_output = [None] * mpu.get_pipeline_model_parallel_world_size() - torch.distributed.all_gather_object(object_list=obj_output, obj=obj, group=mpu.get_pipeline_model_parallel_group()) + torch.distributed.all_gather_object( + object_list=obj_output, obj=obj, group=mpu.get_pipeline_model_parallel_group() + ) src_rank = None target_obj = None @@ -645,12 +708,18 @@ def broadcast_str_from_megatron_pp(obj: Any): assert target_obj is not None, "No valid object found to broadcast." 
- global_rank = torch.distributed.get_global_rank(group=mpu.get_pipeline_model_parallel_group(), group_rank=src_rank) + global_rank = torch.distributed.get_global_rank( + group=mpu.get_pipeline_model_parallel_group(), group_rank=src_rank + ) - obj_output = [None] * torch.distributed.get_world_size(group=mpu.get_pipeline_model_parallel_group()) + obj_output = [None] * torch.distributed.get_world_size( + group=mpu.get_pipeline_model_parallel_group() + ) obj_output[0] = target_obj torch.distributed.broadcast_object_list( - object_list=obj_output, src=global_rank, group=mpu.get_pipeline_model_parallel_group() + object_list=obj_output, + src=global_rank, + group=mpu.get_pipeline_model_parallel_group(), ) return obj_output[0] @@ -690,9 +759,9 @@ def default_tp_concat_fn( num_key_value_heads = hf_config.vision_config.num_heads assert num_attention_heads % num_key_value_heads == 0 num_q_per_kv = num_attention_heads // num_key_value_heads - assert infer_params[0].shape[0] % (num_q_per_kv + 2) == 0, ( - f"param '{name}' shape '{infer_params[0].shape}' dim0 is not divisible by {num_q_per_kv + 2}" - ) + assert ( + infer_params[0].shape[0] % (num_q_per_kv + 2) == 0 + ), f"param '{name}' shape '{infer_params[0].shape}' dim0 is not divisible by {num_q_per_kv + 2}" kv_size_per_tp = infer_params[0].shape[0] // (num_q_per_kv + 2) split_size = [kv_size_per_tp * num_q_per_kv, kv_size_per_tp, kv_size_per_tp] for infer_param in infer_params: @@ -710,7 +779,11 @@ def default_tp_concat_fn( q = torch.cat(q_lst, dim=0) k = torch.cat(k_lst, dim=0) v = torch.cat(v_lst, dim=0) - infer_params = torch.cat((q, k, v), dim=0) if not convert_qkv_gate_up_by_simple_split else [q, k, v] + infer_params = ( + torch.cat((q, k, v), dim=0) + if not convert_qkv_gate_up_by_simple_split + else [q, k, v] + ) elif ( layer_name_mapping.get("gate_proj_layer_name") in name @@ -726,14 +799,20 @@ def default_tp_concat_fn( up_lst.append(up) gate = torch.cat(gate_lst, dim=0) up = torch.cat(up_lst, dim=0) - infer_params 
= torch.cat((gate, up), dim=0) if not convert_qkv_gate_up_by_simple_split else [gate, up] + infer_params = ( + torch.cat((gate, up), dim=0) + if not convert_qkv_gate_up_by_simple_split + else [gate, up] + ) elif "mlp.experts.linear_fc2.weight" in name: # moe infer_params = torch.cat(infer_params, dim=1) else: # concat tensor - infer_params = torch.cat(infer_params, dim=tp_utils.get_tensor_parallel_partition_dim(train_params)) + infer_params = torch.cat( + infer_params, dim=tp_utils.get_tensor_parallel_partition_dim(train_params) + ) return infer_params @@ -768,7 +847,11 @@ def tensor_generator(): # there is a bug in megatron GPTModel # decoder.layers[n].mlp.router.expert_bias" in GPTModel is not registered in named_parameter, but in # state_dict(). for now we patch it by adding those keys to extra_keys. - extra_keys = [x for x in model.state_dict().keys() if "_extra_state" not in x and x not in existing_keys] + extra_keys = [ + x + for x in model.state_dict().keys() + if "_extra_state" not in x and x not in existing_keys + ] for name in extra_keys: yield name, model.state_dict()[name].to(get_device_id()) @@ -780,13 +863,19 @@ def tensor_generator(): for idx, (name, _) in enumerate(model.named_parameters()): existing_keys.add(name) meta_info.append((pp_rank, scan_vpp_idx, idx, name)) - extra_keys = [x for x in model.state_dict().keys() if "_extra_state" not in x and x not in existing_keys] + extra_keys = [ + x + for x in model.state_dict().keys() + if "_extra_state" not in x and x not in existing_keys + ] for name in extra_keys: meta_info.append((pp_rank, scan_vpp_idx, idx, name)) obj_spec_output = [None] * mpu.get_pipeline_model_parallel_world_size() torch.distributed.all_gather_object( - object_list=obj_spec_output, obj=meta_info, group=mpu.get_pipeline_model_parallel_group() + object_list=obj_spec_output, + obj=meta_info, + group=mpu.get_pipeline_model_parallel_group(), ) layer_list_meta = [item for sublist in obj_spec_output for item in sublist] @@ -798,7 +887,8 
@@ def tensor_generator(): import warnings warnings.warn( - "Current model sharing word and embedding weights, skip output layer conversion", stacklevel=2 + "Current model sharing word and embedding weights, skip output layer conversion", + stacklevel=2, ) continue @@ -807,7 +897,9 @@ def tensor_generator(): cur_name, cur_tensor = next(gen_func) except StopIteration: cur_name, cur_tensor = None, None - cur_name = normalize_model_name(name, cur_pp_rank, scan_vpp_idx, transformer_config) + cur_name = normalize_model_name( + name, cur_pp_rank, scan_vpp_idx, transformer_config + ) else: cur_tensor, cur_name = None, None @@ -828,8 +920,13 @@ def tensor_generator(): name_prefix, local_expert_id = cur_name.split(".weight") local_expert_id = int(local_expert_id) - global_expert_ids = [num_experts_per_rank * ep_rank + local_expert_id for ep_rank in range(ep_size)] - global_expert_names = [f"{name_prefix}.weight{expert_id}" for expert_id in global_expert_ids] + global_expert_ids = [ + num_experts_per_rank * ep_rank + local_expert_id + for ep_rank in range(ep_size) + ] + global_expert_names = [ + f"{name_prefix}.weight{expert_id}" for expert_id in global_expert_ids + ] for name, param in zip(global_expert_names, infer_params, strict=True): if etp_size > 1: @@ -851,7 +948,9 @@ def tensor_generator(): ) if not isinstance(merge_params, list): merge_params = [merge_params] - converted_names, converted_params = weight_converter.convert_param(name, merge_params) + converted_names, converted_params = weight_converter.convert_param( + name, merge_params + ) yield from zip(converted_names, converted_params, strict=True) continue @@ -862,8 +961,15 @@ def tensor_generator(): if all_gather_group_size <= 1: infer_params = [broad_pp_tensor] else: - infer_params = [torch.empty_like(broad_pp_tensor) for _ in range(all_gather_group_size)] - torch.distributed.all_gather(infer_params, broad_pp_tensor, group=mpu.get_tensor_model_parallel_group()) + infer_params = [ + 
torch.empty_like(broad_pp_tensor) + for _ in range(all_gather_group_size) + ] + torch.distributed.all_gather( + infer_params, + broad_pp_tensor, + group=mpu.get_tensor_model_parallel_group(), + ) infer_params = default_tp_concat_fn( layer_name_mapping, cur_name, @@ -878,7 +984,9 @@ def tensor_generator(): if not isinstance(infer_params, list): infer_params = [infer_params] - converted_names, converted_params = weight_converter.convert_param(cur_name, infer_params) + converted_names, converted_params = weight_converter.convert_param( + cur_name, infer_params + ) yield from zip(converted_names, converted_params, strict=True) @@ -916,14 +1024,20 @@ def get_transformer_layer_offset(pipeline_rank, vp_rank, config: TransformerConf # are not set, we will not enable uneven pipeline. All layers will be treated # as middle layers. num_layers_in_first_pipeline_stage = ( - 0 if config.num_layers_in_first_pipeline_stage is None else config.num_layers_in_first_pipeline_stage + 0 + if config.num_layers_in_first_pipeline_stage is None + else config.num_layers_in_first_pipeline_stage ) num_layers_in_last_pipeline_stage = ( - 0 if config.num_layers_in_last_pipeline_stage is None else config.num_layers_in_last_pipeline_stage + 0 + if config.num_layers_in_last_pipeline_stage is None + else config.num_layers_in_last_pipeline_stage ) middle_num_layers = ( - config.num_layers - num_layers_in_first_pipeline_stage - num_layers_in_last_pipeline_stage + config.num_layers + - num_layers_in_first_pipeline_stage + - num_layers_in_last_pipeline_stage ) if mpu.get_virtual_pipeline_model_parallel_world_size() is not None: @@ -945,7 +1059,9 @@ def get_transformer_layer_offset(pipeline_rank, vp_rank, config: TransformerConf else config.num_layers_in_last_pipeline_stage // vp_size ) - num_layers_per_vritual_model_chunk_in_middle_pipeline_stage = middle_num_layers // vp_size + num_layers_per_vritual_model_chunk_in_middle_pipeline_stage = ( + middle_num_layers // vp_size + ) # First stage + middle 
stage + last stage total_virtual_chunks = ( @@ -962,22 +1078,31 @@ def get_transformer_layer_offset(pipeline_rank, vp_rank, config: TransformerConf vp_rank * total_virtual_chunks + num_layers_per_virtual_model_chunk_in_first_pipeline_stage + (pipeline_rank - 1) - * (num_layers_per_vritual_model_chunk_in_middle_pipeline_stage // middle_pipeline_stages) + * ( + num_layers_per_vritual_model_chunk_in_middle_pipeline_stage + // middle_pipeline_stages + ) ) else: if middle_pipeline_stages > 0: - num_layers_per_pipeline_rank = middle_num_layers // middle_pipeline_stages + num_layers_per_pipeline_rank = ( + middle_num_layers // middle_pipeline_stages + ) else: num_layers_per_pipeline_rank = 0 middle_pipeline_rank = ( - pipeline_rank if config.num_layers_in_first_pipeline_stage is None else pipeline_rank - 1 + pipeline_rank + if config.num_layers_in_first_pipeline_stage is None + else pipeline_rank - 1 ) if pipeline_rank == 0: offset = 0 else: - offset = (middle_pipeline_rank * num_layers_per_pipeline_rank) + num_layers_in_first_pipeline_stage + offset = ( + middle_pipeline_rank * num_layers_per_pipeline_rank + ) + num_layers_in_first_pipeline_stage else: num_layers = config.num_layers @@ -989,23 +1114,33 @@ def get_transformer_layer_offset(pipeline_rank, vp_rank, config: TransformerConf if config.account_for_loss_in_pipeline_split: num_layers += 1 - num_layers_per_pipeline_rank = num_layers // config.pipeline_model_parallel_size + num_layers_per_pipeline_rank = ( + num_layers // config.pipeline_model_parallel_size + ) if mpu.get_virtual_pipeline_model_parallel_world_size() is not None: vp_size = mpu.get_virtual_pipeline_model_parallel_world_size() num_layers_per_virtual_rank = num_layers_per_pipeline_rank // vp_size total_virtual_chunks = num_layers // vp_size - offset = vp_rank * total_virtual_chunks + (pipeline_rank * num_layers_per_virtual_rank) + offset = vp_rank * total_virtual_chunks + ( + pipeline_rank * num_layers_per_virtual_rank + ) # Reduce the offset of 
embedding layer from the total layer number - if config.account_for_embedding_in_pipeline_split and not mpu.is_pipeline_first_stage(): + if ( + config.account_for_embedding_in_pipeline_split + and not mpu.is_pipeline_first_stage() + ): offset -= 1 else: offset = pipeline_rank * num_layers_per_pipeline_rank # Reduce the offset of embedding layer from the total layer number - if config.account_for_embedding_in_pipeline_split and not mpu.is_pipeline_first_stage(): + if ( + config.account_for_embedding_in_pipeline_split + and not mpu.is_pipeline_first_stage() + ): offset -= 1 else: offset = 0 diff --git a/Agent0/executor_train/verl/verl/utils/memory_buffer.py b/Agent0/executor_train/verl/verl/utils/memory_buffer.py index 9386f0d..7277226 100644 --- a/Agent0/executor_train/verl/verl/utils/memory_buffer.py +++ b/Agent0/executor_train/verl/verl/utils/memory_buffer.py @@ -29,14 +29,25 @@ class MemoryBuffer: memory. It must have a unique type to support this behavior. """ - def __init__(self, numel: int, numel_padded: int, dtype: torch.dtype, source: Optional[torch.Tensor] = None): + def __init__( + self, + numel: int, + numel_padded: int, + dtype: torch.dtype, + source: Optional[torch.Tensor] = None, + ): self.numel = numel self.numel_padded = numel_padded self.dtype = dtype if source is not None: self.data = source else: - self.data = torch.zeros(self.numel_padded, dtype=self.dtype, device=get_device_name(), requires_grad=False) + self.data = torch.zeros( + self.numel_padded, + dtype=self.dtype, + device=get_device_name(), + requires_grad=False, + ) def zero(self): """Reset the buffer to zero.""" @@ -69,7 +80,9 @@ def get_weight_buffer_meta_from_module(module: nn.Module) -> dict[str, dict]: return weight_buffer_meta -def build_memory_buffer(weight_buffer_meta: dict[str, dict]) -> dict[torch.dtype, MemoryBuffer]: +def build_memory_buffer( + weight_buffer_meta: dict[str, dict], +) -> dict[torch.dtype, MemoryBuffer]: """Build the memory buffer given weight_buffer_meta Args: 
@@ -99,14 +112,18 @@ def build_memory_buffer(weight_buffer_meta: dict[str, dict]) -> dict[torch.dtype def build_memory_reference_from_module( - module: torch.nn.Module, memory_buffers: dict[torch.dtype, MemoryBuffer], maintain_weight=True + module: torch.nn.Module, + memory_buffers: dict[torch.dtype, MemoryBuffer], + maintain_weight=True, ): start_index = {} for dtype in memory_buffers: start_index[dtype] = 0 for name, param in sorted(module.named_parameters()): memory_buffer = memory_buffers[param.dtype] - buffer = memory_buffer.get(shape=param.shape, start_index=start_index[param.dtype]) + buffer = memory_buffer.get( + shape=param.shape, start_index=start_index[param.dtype] + ) # need to increment start_index start_index[param.dtype] += calc_padded_numel(param.shape, param.dtype) if maintain_weight: @@ -114,7 +131,9 @@ def build_memory_reference_from_module( param.data = buffer -def build_memory_reference(weight_buffer_meta: dict[str, dict], memory_buffers: dict[torch.dtype, MemoryBuffer]): +def build_memory_reference( + weight_buffer_meta: dict[str, dict], memory_buffers: dict[torch.dtype, MemoryBuffer] +): """Build the memory references. The memory buffers are built using the build_memory_buffer API. This API will allocate a weight buffer pointer to the memory buffer according to the weight_buffer_meta. 
@@ -202,7 +221,9 @@ def initialize_weight_buffer(self, weight_buffer_meta_pp: list[dict[str, dict]]) def build_memory_reference(self): for i, weight_buffer_meta in enumerate(self.weight_buffer_meta_pp): - self._weight_buffers[i] = build_memory_reference(weight_buffer_meta, self._memory_buffers[i]) + self._weight_buffers[i] = build_memory_reference( + weight_buffer_meta, self._memory_buffers[i] + ) self._named_parameters = self.transform_memory_param_fn(self._weight_buffers) @property diff --git a/Agent0/executor_train/verl/verl/utils/model.py b/Agent0/executor_train/verl/verl/utils/model.py index 04cc34f..ddf18bf 100644 --- a/Agent0/executor_train/verl/verl/utils/model.py +++ b/Agent0/executor_train/verl/verl/utils/model.py @@ -64,13 +64,17 @@ def update_model_config(module_config, override_config_kwargs): setattr(module_config, key, val) -def get_huggingface_actor_config(model_name: str, override_config_kwargs=None, trust_remote_code=False) -> dict: +def get_huggingface_actor_config( + model_name: str, override_config_kwargs=None, trust_remote_code=False +) -> dict: if override_config_kwargs is None: override_config_kwargs = {} - assert isinstance(override_config_kwargs, dict), ( - f"override_config_kwargs must be a dict, got {type(override_config_kwargs)}" + assert isinstance( + override_config_kwargs, dict + ), f"override_config_kwargs must be a dict, got {type(override_config_kwargs)}" + module_config = AutoConfig.from_pretrained( + model_name, trust_remote_code=trust_remote_code ) - module_config = AutoConfig.from_pretrained(model_name, trust_remote_code=trust_remote_code) update_model_config(module_config, override_config_kwargs) return module_config @@ -93,7 +97,9 @@ def get_generation_config( return None -def create_huggingface_actor(model_name: str, override_config_kwargs=None, automodel_kwargs=None) -> nn.Module: +def create_huggingface_actor( + model_name: str, override_config_kwargs=None, automodel_kwargs=None +) -> nn.Module: """ Args: @@ -107,17 
+113,23 @@ def create_huggingface_actor(model_name: str, override_config_kwargs=None, autom override_config_kwargs = {} if automodel_kwargs is None: automodel_kwargs = {} - assert isinstance(override_config_kwargs, dict), ( - f"override_config_kwargs must be a dict, got {type(override_config_kwargs)}" - ) + assert isinstance( + override_config_kwargs, dict + ), f"override_config_kwargs must be a dict, got {type(override_config_kwargs)}" module_config = get_huggingface_actor_config( - model_name, override_config_kwargs, trust_remote_code=automodel_kwargs.get("trust_remote_code", False) + model_name, + override_config_kwargs, + trust_remote_code=automodel_kwargs.get("trust_remote_code", False), + ) + module: nn.Module = AutoModelForCausalLM.from_config( + module_config, **automodel_kwargs ) - module: nn.Module = AutoModelForCausalLM.from_config(module_config, **automodel_kwargs) return module -def create_huggingface_critic(model_name: str, override_config_kwargs=None, automodel_kwargs=None) -> nn.Module: +def create_huggingface_critic( + model_name: str, override_config_kwargs=None, automodel_kwargs=None +) -> nn.Module: """ Args: @@ -128,13 +140,16 @@ def create_huggingface_critic(model_name: str, override_config_kwargs=None, auto """ critic_module: nn.Module = create_huggingface_actor( - model_name, override_config_kwargs=override_config_kwargs, automodel_kwargs=automodel_kwargs + model_name, + override_config_kwargs=override_config_kwargs, + automodel_kwargs=automodel_kwargs, ) if automodel_kwargs is None: automodel_kwargs = {} torch_dtype = automodel_kwargs.get("torch_dtype", torch.float32) critic_module.lm_head = nn.Sequential( - nn.Linear(critic_module.config.hidden_size, 1, dtype=torch_dtype), LambdaLayer(fn=squeeze) + nn.Linear(critic_module.config.hidden_size, 1, dtype=torch_dtype), + LambdaLayer(fn=squeeze), ) return critic_module @@ -205,8 +220,12 @@ def create_random_mask( masks = torch.ones_like(input_ids, dtype=torch.int64) # TODO: we can make this 
faster for i in range(batch_size): - num_left_padding = np.random.randint(low=0, high=max_left_padding + 1, dtype=np.int64) - num_valid = np.random.randint(low=min_num_valid_tokens, high=max_num_valid_tokens + 1, dtype=np.int64) + num_left_padding = np.random.randint( + low=0, high=max_left_padding + 1, dtype=np.int64 + ) + num_valid = np.random.randint( + low=min_num_valid_tokens, high=max_num_valid_tokens + 1, dtype=np.int64 + ) for index in range(num_left_padding): masks[i, index] = 0 @@ -225,11 +244,15 @@ def convert_weight_keys(state_dict: dict[str, torch.Tensor], model: PreTrainedMo if not hasattr(model, "_checkpoint_conversion_mapping"): return state_dict - reverse_key_mapping = {v: k for k, v in model._checkpoint_conversion_mapping.items()} + reverse_key_mapping = { + v: k for k, v in model._checkpoint_conversion_mapping.items() + } original_weights = {} for key, value in state_dict.items(): for pattern, replacement in reverse_key_mapping.items(): - replacement = replacement.lstrip("^") # strip off un-needed chars and patterns + replacement = replacement.lstrip( + "^" + ) # strip off un-needed chars and patterns replacement = re.sub(r"\(.*\)", "", replacement) key, n_replace = re.subn(pattern, replacement, key) # Early exit of the loop @@ -259,7 +282,9 @@ def check_exclude_modules(config, key: str) -> bool: return True elif key in config.exclude_modules: return True - elif any(key.endswith(f".{exclude_key}") for exclude_key in config.exclude_modules): + elif any( + key.endswith(f".{exclude_key}") for exclude_key in config.exclude_modules + ): return True return False @@ -282,7 +307,9 @@ def check_target_modules(config, key: str) -> bool: # this module is specified directly in target_modules target_module_found = True else: - target_module_found = any(key.endswith(f".{target_key}") for target_key in config.target_modules) + target_module_found = any( + key.endswith(f".{target_key}") for target_key in config.target_modules + ) layer_indexes = getattr(config, 
"layers_to_transform", None) layers_pattern = getattr(config, "layers_pattern", None) @@ -297,7 +324,11 @@ def check_target_modules(config, key: str) -> bool: if layers_pattern is None or len(layers_pattern) == 0: layer_index = re.match(r".*\.[^.]*\.(\d+)\.", key) else: - layers_pattern = [layers_pattern] if isinstance(layers_pattern, str) else layers_pattern + layers_pattern = ( + [layers_pattern] + if isinstance(layers_pattern, str) + else layers_pattern + ) for pattern in layers_pattern: layer_index = re.match(rf".*\.{pattern}\.(\d+)\.", key) if layer_index is not None: @@ -315,7 +346,9 @@ def check_target_modules(config, key: str) -> bool: return target_module_found -def normalize_model_name(name, pp_rank, vpp_rank, transformer_config, layer_name="layers"): +def normalize_model_name( + name, pp_rank, vpp_rank, transformer_config, layer_name="layers" +): """ Transform the model name in each model_chunk in each pp stage into the name in inference engine """ @@ -355,13 +388,24 @@ def normalize_pp_vpp_params(params, num_hidden_layers, layer_name="layers"): for vpp_rank in range(vpp_size): for name, param in params[pp_rank][vpp_rank].items(): normalized_name = normalize_model_name( - name, pp_rank, vpp_rank, pp_size, vpp_size, num_hidden_layers, layer_name=layer_name + name, + pp_rank, + vpp_rank, + pp_size, + vpp_size, + num_hidden_layers, + layer_name=layer_name, ) yield normalized_name, param def get_parallel_model_from_config( - config, megatron_config, pre_process=None, post_process=None, share_embeddings_and_output_weights=False, value=False + config, + megatron_config, + pre_process=None, + post_process=None, + share_embeddings_and_output_weights=False, + value=False, ): from megatron.core import ModelParallelConfig @@ -378,7 +422,9 @@ def get_parallel_model_from_config( return model -def _get_parallel_model_architecture_from_config(config: PretrainedConfig, value=False) -> type[nn.Module]: +def _get_parallel_model_architecture_from_config( + config: 
PretrainedConfig, value=False +) -> type[nn.Module]: architectures = getattr(config, "architectures", []) for arch in architectures: model_cls = ModelRegistry.load_model_cls(arch, value) @@ -398,7 +444,9 @@ def _load_hf_model(config, model_config, is_value_model, local_cache_path): from verl.models.mcore.saver import _megatron_calc_global_rank - assert hasattr(model_config, "architectures"), "architectures cannot be empty when load weight!" + assert hasattr( + model_config, "architectures" + ), "architectures cannot be empty when load weight!" architectures = getattr(model_config, "architectures", []) local_cache_path = os.path.expanduser(local_cache_path) @@ -407,16 +455,24 @@ def _load_hf_model(config, model_config, is_value_model, local_cache_path): print(f"start download from {config.model.path}") local_model_path = copy_to_local( - src=config.model.path, cache_dir=local_cache_path, use_shm=config.model.get("use_shm", False) + src=config.model.path, + cache_dir=local_cache_path, + use_shm=config.model.get("use_shm", False), ) print("finish download") else: local_model_path = config.model.path print(f"load from local dir {local_model_path}") - src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=0, cp_rank=mpu.get_context_parallel_rank()) + src_rank = _megatron_calc_global_rank( + tp_rank=0, dp_rank=0, pp_rank=0, cp_rank=mpu.get_context_parallel_rank() + ) cpu_init_weights = lambda: torch.device("cpu") - init_context = init_empty_weights if torch.distributed.get_rank() != src_rank else cpu_init_weights + init_context = ( + init_empty_weights + if torch.distributed.get_rank() != src_rank + else cpu_init_weights + ) with init_context(), warnings.catch_warnings(): warnings.simplefilter("ignore") # TODO: to find a better way to load mistral7b-rm lm_head @@ -429,7 +485,9 @@ def _load_hf_model(config, model_config, is_value_model, local_cache_path): ) # use score head instead of lm_head state_dict = model.state_dict() state_dict["lm_head.weight"] = 
state_dict["score.weight"] - state_dict["model.embed_tokens.weight"] = state_dict["model.embed_tokens.weight"][ + state_dict["model.embed_tokens.weight"] = state_dict[ + "model.embed_tokens.weight" + ][ :32000 ] # workaround, 32001 -> 32000 is_value_model = True @@ -451,7 +509,9 @@ def get_hf_model_path(config, local_cache_path="~/.cache/verl/rlhf"): from verl.utils.fs import copy_to_local local_model_path = copy_to_local( - src=config.model.path, cache_dir=local_cache_path, use_shm=config.model.get("use_shm", False) + src=config.model.path, + cache_dir=local_cache_path, + use_shm=config.model.get("use_shm", False), ) else: local_model_path = config.model.path @@ -459,7 +519,12 @@ def get_hf_model_path(config, local_cache_path="~/.cache/verl/rlhf"): def load_megatron_model_weights( - config, model_config, parallel_model, params_dtype, is_value_model=False, local_cache_path="~/.cache/verl/rlhf" + config, + model_config, + parallel_model, + params_dtype, + is_value_model=False, + local_cache_path="~/.cache/verl/rlhf", ): """Load weights for verl customized model.""" architectures, model, state_dict, is_value_model = _load_hf_model( @@ -484,10 +549,17 @@ def load_megatron_model_weights( def load_megatron_gptmodel_weights( - config, model_config, parallel_model, params_dtype, is_value_model=False, local_cache_path="~/.cache/verl/rlhf" + config, + model_config, + parallel_model, + params_dtype, + is_value_model=False, + local_cache_path="~/.cache/verl/rlhf", ): """Load weights for mcore GPT model.""" - _, model, state_dict, is_value_model = _load_hf_model(config, model_config, is_value_model, local_cache_path) + _, model, state_dict, is_value_model = _load_hf_model( + config, model_config, is_value_model, local_cache_path + ) from verl.models.mcore.loader import load_state_dict_to_megatron_gptmodel @@ -502,7 +574,9 @@ def load_megatron_gptmodel_weights( # pad input_ids_rmpad, cu_seqlens and max_seqlen_in_batch to be divisible by tp -def pad_packed_inputs(unpad_tokens: 
torch.Tensor, cu_seqlens, max_seqlen_in_batch, size): +def pad_packed_inputs( + unpad_tokens: torch.Tensor, cu_seqlens, max_seqlen_in_batch, size +): """pad the tokens such that the total length is a multiple of size. This function is useful when applying sequence parallel and context parallel @@ -527,7 +601,9 @@ def pad_packed_inputs(unpad_tokens: torch.Tensor, cu_seqlens, max_seqlen_in_batc elif unpad_tokens.ndim == 2: unpad_tokens = F.pad(unpad_tokens, (0, 0, 0, pad_size)) else: - raise NotImplementedError(f"Padding dim {unpad_tokens.ndim()} is not supported") + raise NotImplementedError( + f"Padding dim {unpad_tokens.ndim()} is not supported" + ) cu_seqlens = F.pad(cu_seqlens, (0, 1), value=pad_size + cu_seqlens[-1]) max_seqlen_in_batch = max(max_seqlen_in_batch, pad_size) @@ -555,18 +631,29 @@ def load_mcore_dist_weights(parallel_model, dist_weight_path, is_value_model=Fal def get_parallel_gptmodel_from_config( - tfconfig, hf_config, pre_process=None, post_process=None, share_embeddings_and_output_weights=False, value=False + tfconfig, + hf_config, + pre_process=None, + post_process=None, + share_embeddings_and_output_weights=False, + value=False, ): from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec from megatron.core.models.gpt.gpt_model import GPTModel use_te = True assert tfconfig.normalization == "RMSNorm", "only RMSNorm is supported for now" - transformer_layer_spec = get_gpt_decoder_block_spec(tfconfig, use_transformer_engine=use_te) + transformer_layer_spec = get_gpt_decoder_block_spec( + tfconfig, use_transformer_engine=use_te + ) rope_scaling_args = {} if hf_config.rope_scaling is not None: - assert hf_config.rope_scaling["type"] == "linear", "only linear scaling is supported for now" - rope_scaling_args["seq_len_interpolation_factor"] = hf_config.rope_scaling["factor"] + assert ( + hf_config.rope_scaling["type"] == "linear" + ), "only linear scaling is supported for now" + 
rope_scaling_args["seq_len_interpolation_factor"] = hf_config.rope_scaling[ + "factor" + ] parallel_model = GPTModel( config=tfconfig, transformer_layer_spec=transformer_layer_spec, @@ -600,18 +687,24 @@ def tie_weights(self: "AutoModelForCausalLMWithValueHead") -> None: if isinstance(self.pretrained_model, PreTrainedModel): self.pretrained_model.tie_weights() - def get_input_embeddings(self: "AutoModelForCausalLMWithValueHead") -> torch.nn.Module: + def get_input_embeddings( + self: "AutoModelForCausalLMWithValueHead", + ) -> torch.nn.Module: if isinstance(self.pretrained_model, PreTrainedModel): return self.pretrained_model.get_input_embeddings() - def get_output_embeddings(self: "AutoModelForCausalLMWithValueHead") -> torch.nn.Module: + def get_output_embeddings( + self: "AutoModelForCausalLMWithValueHead", + ) -> torch.nn.Module: if isinstance(self.pretrained_model, PreTrainedModel): return self.pretrained_model.get_output_embeddings() def can_generate(self): return False - ignore_modules = [name for name, _ in model.named_parameters() if "pretrained_model" in name] + ignore_modules = [ + name for name, _ in model.named_parameters() if "pretrained_model" in name + ] model._keys_to_ignore_on_save = ignore_modules model.tie_weights = MethodType(tie_weights, model) model.get_input_embeddings = MethodType(get_input_embeddings, model) @@ -621,7 +714,11 @@ def can_generate(self): def load_valuehead_model(local_path, torch_dtype, model_config, trust_remote_code): - from transformers import AutoModelForCausalLM, AutoModelForTokenClassification, AutoModelForVision2Seq + from transformers import ( + AutoModelForCausalLM, + AutoModelForTokenClassification, + AutoModelForVision2Seq, + ) try: model = AutoModelForTokenClassification.from_pretrained( diff --git a/Agent0/executor_train/verl/verl/utils/profiler/__init__.py b/Agent0/executor_train/verl/verl/utils/profiler/__init__.py index 2242c24..da7b50e 100644 --- a/Agent0/executor_train/verl/verl/utils/profiler/__init__.py 
+++ b/Agent0/executor_train/verl/verl/utils/profiler/__init__.py @@ -19,10 +19,20 @@ if is_nvtx_available(): from .nvtx_profile import NsightSystemsProfiler as DistProfiler - from .nvtx_profile import mark_annotate, mark_end_range, mark_start_range, marked_timer + from .nvtx_profile import ( + mark_annotate, + mark_end_range, + mark_start_range, + marked_timer, + ) elif is_npu_available: from .mstx_profile import NPUProfiler as DistProfiler - from .mstx_profile import mark_annotate, mark_end_range, mark_start_range, marked_timer + from .mstx_profile import ( + mark_annotate, + mark_end_range, + mark_start_range, + marked_timer, + ) else: from .performance import marked_timer from .profile import DistProfiler, mark_annotate, mark_end_range, mark_start_range diff --git a/Agent0/executor_train/verl/verl/utils/profiler/config.py b/Agent0/executor_train/verl/verl/utils/profiler/config.py index 8acf075..b355a19 100644 --- a/Agent0/executor_train/verl/verl/utils/profiler/config.py +++ b/Agent0/executor_train/verl/verl/utils/profiler/config.py @@ -52,6 +52,6 @@ def intersect(self, other: "ProfilerConfig") -> "ProfilerConfig": def __post_init__(self) -> None: """config validation logics go here""" - assert isinstance(self.ranks, set | list | tuple), ( - f"Profiler ranks must be of type list, got {type(self.ranks)}" - ) + assert isinstance( + self.ranks, set | list | tuple + ), f"Profiler ranks must be of type list, got {type(self.ranks)}" diff --git a/Agent0/executor_train/verl/verl/utils/profiler/mstx_profile.py b/Agent0/executor_train/verl/verl/utils/profiler/mstx_profile.py index c5c35ce..ff4839e 100644 --- a/Agent0/executor_train/verl/verl/utils/profiler/mstx_profile.py +++ b/Agent0/executor_train/verl/verl/utils/profiler/mstx_profile.py @@ -81,7 +81,9 @@ def marked_timer(name: str, timing_raw: dict[str, float], **kwargs): mark_end_range(mark_range) -def get_npu_profiler(option: DictConfig, role: Optional[str] = None, profile_step: Optional[str] = None): +def 
get_npu_profiler( + option: DictConfig, role: Optional[str] = None, profile_step: Optional[str] = None +): """Generate and return an NPU profiler object. Args: @@ -101,7 +103,9 @@ def get_npu_profiler(option: DictConfig, role: Optional[str] = None, profile_ste elif option.level == "level2": profile_level = torch_npu.profiler.ProfilerLevel.Level2 else: - raise ValueError(f"level only supports level0, 1, 2, and level_none, but gets {option.level}") + raise ValueError( + f"level only supports level0, 1, 2, and level_none, but gets {option.level}" + ) profile_save_path = option.save_path if profile_step: @@ -129,7 +133,9 @@ def get_npu_profiler(option: DictConfig, role: Optional[str] = None, profile_ste record_shapes=option.record_shapes, profile_memory=option.with_memory, activities=activites, - on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(profile_save_path, analyse_flag=option.analysis), + on_trace_ready=torch_npu.profiler.tensorboard_trace_handler( + profile_save_path, analyse_flag=option.analysis + ), experimental_config=experimental_config, ) return prof @@ -167,7 +173,9 @@ def start(self, **kwargs): if self.this_rank and self.profile_option is not None: self.this_step = True if not self.discrete and NPUProfiler._define_count == 0: - self.profile_npu = get_npu_profiler(option=self.profile_option, role=role, profile_step=profile_step) + self.profile_npu = get_npu_profiler( + option=self.profile_option, role=role, profile_step=profile_step + ) self.profile_npu.start() NPUProfiler._define_count += 1 @@ -180,7 +188,9 @@ def stop(self): NPUProfiler._define_count -= 1 @staticmethod - def annotate(message: Optional[str] = None, role: Optional[str] = None, **kwargs) -> Callable: + def annotate( + message: Optional[str] = None, role: Optional[str] = None, **kwargs + ) -> Callable: """Decorate a Worker member function to profile the current rank in the current training step. 
Requires the target function to be a member function of a Worker, @@ -200,7 +210,9 @@ def wrapper(self, *args, **kwargs): if self.profiler.this_step and self.profile_option is not None: if self.profiler.discrete: - profile_npu = get_npu_profiler(option=self.profile_option, role=role) + profile_npu = get_npu_profiler( + option=self.profile_option, role=role + ) profile_npu.start() mark_range = mark_start_range(message=profile_name) diff --git a/Agent0/executor_train/verl/verl/utils/profiler/nvtx_profile.py b/Agent0/executor_train/verl/verl/utils/profiler/nvtx_profile.py index 9ebce37..90d2116 100644 --- a/Agent0/executor_train/verl/verl/utils/profiler/nvtx_profile.py +++ b/Agent0/executor_train/verl/verl/utils/profiler/nvtx_profile.py @@ -41,7 +41,9 @@ def mark_start_range( category (str, optional): The category of the range. Defaults to None. """ - return nvtx.start_range(message=message, color=color, domain=domain, category=category) + return nvtx.start_range( + message=message, color=color, domain=domain, category=category + ) def mark_end_range(range_id: str) -> None: @@ -75,7 +77,9 @@ def mark_annotate( def decorator(func): profile_message = message or func.__name__ - return nvtx.annotate(profile_message, color=color, domain=domain, category=category)(func) + return nvtx.annotate( + profile_message, color=color, domain=domain, category=category + )(func) return decorator @@ -103,7 +107,9 @@ def marked_timer( Yields: None: This is a context manager that yields control back to the code block. 
""" - mark_range = mark_start_range(message=name, color=color, domain=domain, category=category) + mark_range = mark_start_range( + message=name, color=color, domain=domain, category=category + ) from .performance import _timer yield from _timer(name, timing_raw) @@ -175,7 +181,12 @@ def wrapper(self, *args, **kwargs): if self.profiler.this_step: if self.profiler.discrete: torch.cuda.profiler.start() - mark_range = mark_start_range(message=profile_name, color=color, domain=domain, category=category) + mark_range = mark_start_range( + message=profile_name, + color=color, + domain=domain, + category=category, + ) result = func(self, *args, **kwargs) diff --git a/Agent0/executor_train/verl/verl/utils/profiler/performance.py b/Agent0/executor_train/verl/verl/utils/profiler/performance.py index 8991896..59948bf 100644 --- a/Agent0/executor_train/verl/verl/utils/profiler/performance.py +++ b/Agent0/executor_train/verl/verl/utils/profiler/performance.py @@ -44,7 +44,9 @@ def _get_current_mem_info(unit: str = "GB", precision: int = 2) -> tuple[str]: return mem_allocated, mem_reserved, mem_used, mem_total -def log_gpu_memory_usage(head: str, logger: logging.Logger = None, level=logging.DEBUG, rank: int = 0): +def log_gpu_memory_usage( + head: str, logger: logging.Logger = None, level=logging.DEBUG, rank: int = 0 +): """Log GPU memory usage information. Args: @@ -77,7 +79,13 @@ class GPUMemoryLogger(DecoratorLoggerBase): ... 
return """ - def __init__(self, role: str, logger: logging.Logger = None, level=logging.DEBUG, log_only_rank_0: bool = True): + def __init__( + self, + role: str, + logger: logging.Logger = None, + level=logging.DEBUG, + log_only_rank_0: bool = True, + ): if dist.is_initialized() and dist.get_world_size() > 1: rank = dist.get_rank() else: diff --git a/Agent0/executor_train/verl/verl/utils/profiler/profile.py b/Agent0/executor_train/verl/verl/utils/profiler/profile.py index 4e7ce4f..1baf7ca 100644 --- a/Agent0/executor_train/verl/verl/utils/profiler/profile.py +++ b/Agent0/executor_train/verl/verl/utils/profiler/profile.py @@ -70,11 +70,15 @@ def _validate(self): if self.config.profile_ranks is None: print("[WARNING] Profile ranks is not set, default to rank 0") self.config.profile_ranks = [0] - assert self.config.step_start >= 0, "[ERROR] Profile step start must be greater than 0" - assert self.config.step_end >= 0, "[ERROR] Profile step end must be greater than 0" - assert self.config.step_start < self.config.step_end, ( - "[ERROR] Profile step start must be less than step end" - ) + assert ( + self.config.step_start >= 0 + ), "[ERROR] Profile step start must be greater than 0" + assert ( + self.config.step_end >= 0 + ), "[ERROR] Profile step end must be greater than 0" + assert ( + self.config.step_start < self.config.step_end + ), "[ERROR] Profile step start must be less than step end" def check(self): return self.prof is not None and not self.skip_prof @@ -98,7 +102,9 @@ def save(self): if not os.path.exists(self.config.save_path): os.makedirs(self.config.save_path) save_file_name = f"/prof_start_{self.config.step_start}_end_{self.config.step_end}_rank_{self.rank}.json" - print(f"[Profiler] Saving trace to {self.config.save_path + save_file_name}") + print( + f"[Profiler] Saving trace to {self.config.save_path + save_file_name}" + ) self.prof.export_chrome_trace(self.config.save_path + save_file_name) self.skip_prof = True self.saved = True diff --git 
a/Agent0/executor_train/verl/verl/utils/py_functional.py b/Agent0/executor_train/verl/verl/utils/py_functional.py index 1ea02ef..22fefec 100644 --- a/Agent0/executor_train/verl/verl/utils/py_functional.py +++ b/Agent0/executor_train/verl/verl/utils/py_functional.py @@ -27,7 +27,12 @@ # --- Top-level helper for multiprocessing timeout --- # This function MUST be defined at the top level to be pickleable -def _mp_target_wrapper(target_func: Callable, mp_queue: multiprocessing.Queue, args: tuple, kwargs: dict[str, Any]): +def _mp_target_wrapper( + target_func: Callable, + mp_queue: multiprocessing.Queue, + args: tuple, + kwargs: dict[str, Any], +): """ Internal wrapper function executed in the child process. Calls the original target function and puts the result or exception into the queue. @@ -44,7 +49,14 @@ def _mp_target_wrapper(target_func: Callable, mp_queue: multiprocessing.Queue, a mp_queue.put((False, e)) # Indicate failure and put exception except (pickle.PicklingError, TypeError): # Fallback if the original exception cannot be pickled - mp_queue.put((False, RuntimeError(f"Original exception type {type(e).__name__} not pickleable: {e}"))) + mp_queue.put( + ( + False, + RuntimeError( + f"Original exception type {type(e).__name__} not pickleable: {e}" + ), + ) + ) # Renamed the function from timeout to timeout_limit @@ -82,7 +94,9 @@ def decorator(func): def wrapper_signal(*args, **kwargs): def handler(signum, frame): # Update function name in error message if needed (optional but good practice) - raise TimeoutError(f"Function {func.__name__} timed out after {seconds} seconds (signal)!") + raise TimeoutError( + f"Function {func.__name__} timed out after {seconds} seconds (signal)!" 
+ ) old_handler = signal.getsignal(signal.SIGALRM) signal.signal(signal.SIGALRM, handler) @@ -103,7 +117,9 @@ def handler(signum, frame): @wraps(func) def wrapper_mp(*args, **kwargs): q = multiprocessing.Queue(maxsize=1) - process = multiprocessing.Process(target=_mp_target_wrapper, args=(func, q, args, kwargs)) + process = multiprocessing.Process( + target=_mp_target_wrapper, args=(func, q, args, kwargs) + ) process.start() process.join(timeout=seconds) @@ -111,12 +127,18 @@ def wrapper_mp(*args, **kwargs): process.terminate() process.join(timeout=0.5) # Give it a moment to terminate if process.is_alive(): - print(f"Warning: Process {process.pid} did not terminate gracefully after timeout.") + print( + f"Warning: Process {process.pid} did not terminate gracefully after timeout." + ) # Update function name in error message if needed (optional but good practice) - raise TimeoutError(f"Function {func.__name__} timed out after {seconds} seconds (multiprocessing)!") + raise TimeoutError( + f"Function {func.__name__} timed out after {seconds} seconds (multiprocessing)!" 
+ ) try: - success, result_or_exc = q.get(timeout=0.1) # Small timeout for queue read + success, result_or_exc = q.get( + timeout=0.1 + ) # Small timeout for queue read if success: return result_or_exc else: @@ -155,7 +177,9 @@ def union_two_dict(dict1: dict, dict2: dict): """ for key, val in dict2.items(): if key in dict1: - assert dict2[key] == dict1[key], f"{key} in meta_dict1 and meta_dict2 are not the same object" + assert ( + dict2[key] == dict1[key] + ), f"{key} in meta_dict1 and meta_dict2 are not the same object" dict1[key] = val return dict1 @@ -277,7 +301,11 @@ def convert_to_regular_types(obj): from omegaconf import DictConfig, ListConfig if isinstance(obj, ListConfig | DictConfig): - return {k: convert_to_regular_types(v) for k, v in obj.items()} if isinstance(obj, DictConfig) else list(obj) + return ( + {k: convert_to_regular_types(v) for k, v in obj.items()} + if isinstance(obj, DictConfig) + else list(obj) + ) elif isinstance(obj, list | tuple): return [convert_to_regular_types(x) for x in obj] elif isinstance(obj, dict): diff --git a/Agent0/executor_train/verl/verl/utils/ray_utils.py b/Agent0/executor_train/verl/verl/utils/ray_utils.py index a738c0f..1587b80 100644 --- a/Agent0/executor_train/verl/verl/utils/ray_utils.py +++ b/Agent0/executor_train/verl/verl/utils/ray_utils.py @@ -67,7 +67,9 @@ def put_data(index, data): max_workers = min(len(data_list), 16) with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: - data_list_f = [executor.submit(put_data, i, data) for i, data in enumerate(data_list)] + data_list_f = [ + executor.submit(put_data, i, data) for i, data in enumerate(data_list) + ] res_lst = [] for future in concurrent.futures.as_completed(data_list_f): res_lst.append(future.result()) diff --git a/Agent0/executor_train/verl/verl/utils/rendezvous/ray_backend.py b/Agent0/executor_train/verl/verl/utils/rendezvous/ray_backend.py index d991181..b4bcd87 100644 --- 
a/Agent0/executor_train/verl/verl/utils/rendezvous/ray_backend.py +++ b/Agent0/executor_train/verl/verl/utils/rendezvous/ray_backend.py @@ -43,7 +43,11 @@ def get_nccl_id_store_by_name(name): def create_nccl_communicator_in_ray( - rank: int, world_size: int, group_name: str, max_retries: int = 100, interval_s: int = 5 + rank: int, + world_size: int, + group_name: str, + max_retries: int = 100, + interval_s: int = 5, ): if rank == 0: nccl_id = get_unique_id() @@ -69,5 +73,9 @@ def create_nccl_communicator_in_ray( rank=rank, ) return communicator - logging.info("failed to get nccl_id for %d time, sleep for %d seconds", i + 1, interval_s) + logging.info( + "failed to get nccl_id for %d time, sleep for %d seconds", + i + 1, + interval_s, + ) time.sleep(interval_s) diff --git a/Agent0/executor_train/verl/verl/utils/reward_score/__init__.py b/Agent0/executor_train/verl/verl/utils/reward_score/__init__.py index b298d41..ecfc4b6 100644 --- a/Agent0/executor_train/verl/verl/utils/reward_score/__init__.py +++ b/Agent0/executor_train/verl/verl/utils/reward_score/__init__.py @@ -77,7 +77,12 @@ def default_compute_score( # Pass the URL directly, ground_truth likely contains test cases here res = sandbox_fusion.compute_score( - sandbox_fusion_url, concurrent_semaphore, memory_limit_mb, solution_str, ground_truth, continuous=True + sandbox_fusion_url, + concurrent_semaphore, + memory_limit_mb, + solution_str, + ground_truth, + continuous=True, ) else: # If no sandbox URL is provided, fall back to prime_code or raise error @@ -103,7 +108,9 @@ def default_compute_score( res = search_r1_like_qa_em.compute_score(solution_str, ground_truth) else: - raise NotImplementedError(f"Reward function is not implemented for {data_source=}") + raise NotImplementedError( + f"Reward function is not implemented for {data_source=}" + ) if isinstance(res, dict): return res @@ -127,7 +134,13 @@ def _default_compute_score( Legacy function API to be deprecated. 
Please use `default_compute_score` instead. """ return default_compute_score( - data_source, solution_str, ground_truth, extra_info, sandbox_fusion_url, concurrent_semaphore, memory_limit_mb + data_source, + solution_str, + ground_truth, + extra_info, + sandbox_fusion_url, + concurrent_semaphore, + memory_limit_mb, ) diff --git a/Agent0/executor_train/verl/verl/utils/reward_score/geo3k.py b/Agent0/executor_train/verl/verl/utils/reward_score/geo3k.py index 8a85087..644494a 100644 --- a/Agent0/executor_train/verl/verl/utils/reward_score/geo3k.py +++ b/Agent0/executor_train/verl/verl/utils/reward_score/geo3k.py @@ -30,7 +30,12 @@ def acc_reward(predict_str: str, ground_truth: str, use_boxed: bool = True) -> f return 1.0 if grade_answer(answer, ground_truth) else 0.0 -def compute_score(predict_str: str, ground_truth: str, use_boxed: bool = True, format_score: float = 0.1) -> float: - return (1.0 - format_score) * acc_reward(predict_str, ground_truth, use_boxed) + format_score * format_reward( - predict_str - ) +def compute_score( + predict_str: str, + ground_truth: str, + use_boxed: bool = True, + format_score: float = 0.1, +) -> float: + return (1.0 - format_score) * acc_reward( + predict_str, ground_truth, use_boxed + ) + format_score * format_reward(predict_str) diff --git a/Agent0/executor_train/verl/verl/utils/reward_score/gsm8k.py b/Agent0/executor_train/verl/verl/utils/reward_score/gsm8k.py index c2afafc..6860cc8 100644 --- a/Agent0/executor_train/verl/verl/utils/reward_score/gsm8k.py +++ b/Agent0/executor_train/verl/verl/utils/reward_score/gsm8k.py @@ -41,7 +41,9 @@ def extract_solution(solution_str, method="strict"): return final_answer -def compute_score(solution_str, ground_truth, method="strict", format_score=0.0, score=1.0): +def compute_score( + solution_str, ground_truth, method="strict", format_score=0.0, score=1.0 +): """The scoring function for GSM8k. Reference: Trung, Luong, et al. "Reft: Reasoning with reinforced fine-tuning." 
Proceedings of the 62nd Annual diff --git a/Agent0/executor_train/verl/verl/utils/reward_score/math_dapo.py b/Agent0/executor_train/verl/verl/utils/reward_score/math_dapo.py index 940500f..38904dd 100644 --- a/Agent0/executor_train/verl/verl/utils/reward_score/math_dapo.py +++ b/Agent0/executor_train/verl/verl/utils/reward_score/math_dapo.py @@ -163,7 +163,10 @@ def normalize_final_answer(final_answer: str) -> str: def is_correct_minerva( - solution_str: str, gt: str, gt_need_extract: bool = False, answer_pattern: str = r"(?i)Answer\s*:\s*([^\n]+)" + solution_str: str, + gt: str, + gt_need_extract: bool = False, + answer_pattern: str = r"(?i)Answer\s*:\s*([^\n]+)", ) -> tuple[bool, str]: """Check if the solution is correct according to Minerva criteria. @@ -218,7 +221,10 @@ def is_correct_strict_box( def verify( - solution_str: str, answer: str, strict_box_verify: bool = False, pause_tokens_index: Optional[list[int]] = None + solution_str: str, + answer: str, + strict_box_verify: bool = False, + pause_tokens_index: Optional[list[int]] = None, ) -> bool: """Verify if the solution is correct. 
@@ -257,10 +263,14 @@ def compute_score( Reward score (1.0 for correct, -1.0 for incorrect) """ # Limit solution length for efficiency - solution_str = solution_str[-300:] # The longest answer in MATH-500 has 159 characters + solution_str = solution_str[ + -300: + ] # The longest answer in MATH-500 has 159 characters # Verify the solution - correct, pred = verify(solution_str, ground_truth, strict_box_verify, pause_tokens_index) + correct, pred = verify( + solution_str, ground_truth, strict_box_verify, pause_tokens_index + ) reward = 1.0 if correct else -1.0 acc = correct diff --git a/Agent0/executor_train/verl/verl/utils/reward_score/math_verify.py b/Agent0/executor_train/verl/verl/utils/reward_score/math_verify.py index c1ce7c1..94b24ec 100644 --- a/Agent0/executor_train/verl/verl/utils/reward_score/math_verify.py +++ b/Agent0/executor_train/verl/verl/utils/reward_score/math_verify.py @@ -17,10 +17,14 @@ from math_verify.metric import math_metric from math_verify.parser import ExprExtractionConfig, LatexExtractionConfig except ImportError: - print("To use Math-Verify, please install it first by running `pip install math-verify`.") + print( + "To use Math-Verify, please install it first by running `pip install math-verify`." 
+ ) -def compute_score(model_output: str, ground_truth: str, timeout_score: float = 0) -> bool: +def compute_score( + model_output: str, ground_truth: str, timeout_score: float = 0 +) -> bool: verify_func = math_metric( gold_extraction_target=(LatexExtractionConfig(),), pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig()), diff --git a/Agent0/executor_train/verl/verl/utils/reward_score/prime_code/__init__.py b/Agent0/executor_train/verl/verl/utils/reward_score/prime_code/__init__.py index 214f99b..aea675d 100644 --- a/Agent0/executor_train/verl/verl/utils/reward_score/prime_code/__init__.py +++ b/Agent0/executor_train/verl/verl/utils/reward_score/prime_code/__init__.py @@ -30,7 +30,9 @@ def compute_score(completion, test_cases, continuous=False): # Complete check on all in-out pairs first. If there is no failure, per-sample test can be skipped. try: - res, metadata = apps_check_correctness(in_outs=test_cases, generation=solution, timeout=5, debug=False) + res, metadata = apps_check_correctness( + in_outs=test_cases, generation=solution, timeout=5, debug=False + ) metadata = dict(enumerate(metadata))[0] success = all(map(lambda x: x is True, res)) if success: @@ -50,9 +52,13 @@ def compute_score(completion, test_cases, continuous=False): metadata_list = [] res_list = [] for test_case_id, test_case in enumerate(test_cases_list): - res, metadata = apps_check_correctness(in_outs=test_case, generation=solution, timeout=10, debug=False) + res, metadata = apps_check_correctness( + in_outs=test_case, generation=solution, timeout=10, debug=False + ) try: - metadata = dict(enumerate(metadata))[0] # metadata can be empty occasionally + metadata = dict(enumerate(metadata))[ + 0 + ] # metadata can be empty occasionally except Exception: metadata = {} metadata["test_case"] = {} diff --git a/Agent0/executor_train/verl/verl/utils/reward_score/prime_code/testing_util.py b/Agent0/executor_train/verl/verl/utils/reward_score/prime_code/testing_util.py index 
2f22325..ec0722f 100644 --- a/Agent0/executor_train/verl/verl/utils/reward_score/prime_code/testing_util.py +++ b/Agent0/executor_train/verl/verl/utils/reward_score/prime_code/testing_util.py @@ -81,7 +81,9 @@ def combined_int_check(val): def clean_traceback(error_traceback): file_start = error_traceback.find('File ""') # print(file_start) - error_traceback = "Traceback (most recent call last):\n " + error_traceback[file_start:] + error_traceback = ( + "Traceback (most recent call last):\n " + error_traceback[file_start:] + ) return error_traceback @@ -147,7 +149,11 @@ def run_test(in_outs, test=None, debug=False, timeout=15): if isinstance(last_block, ast.If): condition = last_block.test if ast.unparse(condition).strip() == "__name__ == '__main__'": - test = ast.unparse(astree.body[:-1]) + "\n" + ast.unparse(last_block.body) + test = ( + ast.unparse(astree.body[:-1]) + + "\n" + + ast.unparse(last_block.body) + ) except Exception: pass @@ -224,7 +230,10 @@ def run_test(in_outs, test=None, debug=False, timeout=15): truncate_line_size = 300 // (raw_inputs.count("\n") + 1) raw_inputs = "\n".join( - [truncatefn(line, truncate_line_size) for line in raw_inputs.strip().split("\n")] + [ + truncatefn(line, truncate_line_size) + for line in raw_inputs.strip().split("\n") + ] ) raw_outputs = truncatefn(raw_outputs, 200) else: @@ -238,12 +247,16 @@ def run_test(in_outs, test=None, debug=False, timeout=15): pass try: if isinstance(in_outs["outputs"][index], dict): - in_outs["outputs"][index] = [{int(k): v for k, v in in_outs["outputs"][index].items()}] + in_outs["outputs"][index] = [ + {int(k): v for k, v in in_outs["outputs"][index].items()} + ] except Exception: pass try: if isinstance(in_outs["outputs"][index][0], dict): - in_outs["outputs"][index] = [{int(k): v for k, v in in_outs["outputs"][index][0].items()}] + in_outs["outputs"][index] = [ + {int(k): v for k, v in in_outs["outputs"][index][0].items()} + ] except Exception: pass @@ -267,13 +280,21 @@ def 
run_test(in_outs, test=None, debug=False, timeout=15): output = list(output) tmp_result = output == in_outs["outputs"][index] - if isinstance(in_outs["outputs"][index], list) and in_outs["outputs"][index]: - tmp_result = tmp_result or (output == in_outs["outputs"][index][0]) + if ( + isinstance(in_outs["outputs"][index], list) + and in_outs["outputs"][index] + ): + tmp_result = tmp_result or ( + output == in_outs["outputs"][index][0] + ) # ground truth sequences are not tuples try: if isinstance(output[0], tuple): - tmp_result = tmp_result or ([list(x) for x in output] == in_outs["outputs"][index][0]) + tmp_result = tmp_result or ( + [list(x) for x in output] + == in_outs["outputs"][index][0] + ) except Exception: pass results.append(tmp_result) @@ -292,7 +313,9 @@ def run_test(in_outs, test=None, debug=False, timeout=15): error_traceback = traceback.format_exc() faulthandler.disable() if debug: - print(f"Standard input runtime error or time limit exceeded error = {e}") + print( + f"Standard input runtime error or time limit exceeded error = {e}" + ) results.append(-1) return results, { "error": repr(e), @@ -325,7 +348,9 @@ def run_test(in_outs, test=None, debug=False, timeout=15): # runtime error or took too long signal.alarm(0) error_traceback = traceback.format_exc() - print(f"Call-based runtime error or time limit exceeded error = {repr(e)}{e}") + print( + f"Call-based runtime error or time limit exceeded error = {repr(e)}{e}" + ) results.append(-1) return results, { "error": repr(e), @@ -352,7 +377,9 @@ def run_test(in_outs, test=None, debug=False, timeout=15): continue if passed and debug: - print(f"==> output = {output}, test outputs = {in_outs['outputs'][index]}") + print( + f"==> output = {output}, test outputs = {in_outs['outputs'][index]}" + ) if custom_compare_(output, in_outs["outputs"][index]): tmp_result = True @@ -369,7 +396,9 @@ def run_test(in_outs, test=None, debug=False, timeout=15): if isinstance(in_outs["outputs"][index], list): tmp_result = 
tmp_result or (output == in_outs["outputs"][index]) if isinstance(output[0], str): - tmp_result = tmp_result or ([e.strip() for e in output] == in_outs["outputs"][index]) + tmp_result = tmp_result or ( + [e.strip() for e in output] == in_outs["outputs"][index] + ) except Exception as e: if debug: print(f"Failed check1 exception = {e}") @@ -388,8 +417,12 @@ def run_test(in_outs, test=None, debug=False, timeout=15): ] else: in_outs["outputs"][index] = in_outs["outputs"][index].split("\n") - in_outs["outputs"][index] = list(filter(len, in_outs["outputs"][index])) - in_outs["outputs"][index] = list(map(lambda x: x.strip(), in_outs["outputs"][index])) + in_outs["outputs"][index] = list( + filter(len, in_outs["outputs"][index]) + ) + in_outs["outputs"][index] = list( + map(lambda x: x.strip(), in_outs["outputs"][index]) + ) try: tmp_result = output == [in_outs["outputs"][index]] @@ -440,20 +473,25 @@ def run_test(in_outs, test=None, debug=False, timeout=15): try: all_ints = all( combined_int_check(e1) and combined_int_check(e2) - for e1, e2 in zip(output, in_outs["outputs"][index], strict=True) + for e1, e2 in zip( + output, in_outs["outputs"][index], strict=True + ) ) if not all_ints: if debug: print( [ combined_int_check(e1) and combined_int_check(e2) - for e1, e2 in zip(output, in_outs["outputs"][index], strict=True) + for e1, e2 in zip( + output, in_outs["outputs"][index], strict=True + ) ] ) output_float = [float(e) for e in output] gt_float = [float(e) for e in in_outs["outputs"][index]] tmp_result = tmp_result or ( - (len(output_float) == len(gt_float)) and np.allclose(output_float, gt_float) + (len(output_float) == len(gt_float)) + and np.allclose(output_float, gt_float) ) except Exception: pass @@ -465,13 +503,16 @@ def run_test(in_outs, test=None, debug=False, timeout=15): if isinstance(output[0], list): all_ints = all( combined_int_check(e1) and combined_int_check(e2) - for e1, e2 in zip(output[0], in_outs["outputs"][index], strict=True) + for e1, e2 in zip( + 
output[0], in_outs["outputs"][index], strict=True + ) ) if not all_ints: output_float = [float(e) for e in output[0]] gt_float = [float(e) for e in in_outs["outputs"][index][0]] tmp_result = tmp_result or ( - (len(output_float) == len(gt_float)) and np.allclose(output_float, gt_float) + (len(output_float) == len(gt_float)) + and np.allclose(output_float, gt_float) ) except Exception: pass @@ -615,10 +656,16 @@ def reliability_guard(maximum_memory_bytes=None): if maximum_memory_bytes is not None: import resource - resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes)) - resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes)) + resource.setrlimit( + resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes) + ) + resource.setrlimit( + resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes) + ) if platform.uname().system != "Darwin": - resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes)) + resource.setrlimit( + resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes) + ) faulthandler.disable() diff --git a/Agent0/executor_train/verl/verl/utils/reward_score/prime_code/utils.py b/Agent0/executor_train/verl/verl/utils/reward_score/prime_code/utils.py index 9123265..f6ab35e 100644 --- a/Agent0/executor_train/verl/verl/utils/reward_score/prime_code/utils.py +++ b/Agent0/executor_train/verl/verl/utils/reward_score/prime_code/utils.py @@ -28,7 +28,9 @@ def _temp_run(sample, generation, debug, result, metadata_list, timeout): sys.stdout = devnull sys.stderr = devnull try: - res, metadata = run_test(in_outs=sample, test=generation, debug=debug, timeout=timeout) + res, metadata = run_test( + in_outs=sample, test=generation, debug=debug, timeout=timeout + ) result.append(res) metadata_list.append(metadata) except Exception: @@ -46,7 +48,10 @@ def check_correctness(in_outs: Optional[dict], generation, timeout=10, debug=Tru manager = multiprocessing.Manager() 
result = manager.list() metadata_list = manager.list() - p = multiprocessing.Process(target=_temp_run, args=(in_outs, generation, debug, result, metadata_list, timeout)) + p = multiprocessing.Process( + target=_temp_run, + args=(in_outs, generation, debug, result, metadata_list, timeout), + ) p.start() p.join(timeout=timeout + 1) if p.is_alive(): diff --git a/Agent0/executor_train/verl/verl/utils/reward_score/prime_math/__init__.py b/Agent0/executor_train/verl/verl/utils/reward_score/prime_math/__init__.py index 04fd146..82a4c86 100644 --- a/Agent0/executor_train/verl/verl/utils/reward_score/prime_math/__init__.py +++ b/Agent0/executor_train/verl/verl/utils/reward_score/prime_math/__init__.py @@ -46,7 +46,10 @@ def _sympy_parse(expr: str): py_expr = expr.replace("^", "**") return sympy_parser.parse_expr( py_expr, - transformations=(sympy_parser.standard_transformations + (sympy_parser.implicit_multiplication_application,)), + transformations=( + sympy_parser.standard_transformations + + (sympy_parser.implicit_multiplication_application,) + ), ) @@ -277,12 +280,17 @@ def grade_answer(given_answer: str, ground_truth: str) -> bool: if ( len(ground_truth_elems) > 1 - and (ground_truth_normalized[0] != given_normalized[0] or ground_truth_normalized[-1] != given_normalized[-1]) + and ( + ground_truth_normalized[0] != given_normalized[0] + or ground_truth_normalized[-1] != given_normalized[-1] + ) or len(ground_truth_elems) != len(given_elems) ): is_correct = False else: - for ground_truth_elem, given_elem in zip(ground_truth_elems, given_elems, strict=True): + for ground_truth_elem, given_elem in zip( + ground_truth_elems, given_elems, strict=True + ): if _is_frac(ground_truth_elem) and _is_frac(given_elem): # if fractions aren't reduced, then shouldn't be marked as correct # so, we don't want to allow sympy.simplify in this case @@ -297,7 +305,9 @@ def grade_answer(given_answer: str, ground_truth: str) -> bool: except Exception as e: # if there's an error, we'll just 
say it's not correct is_correct = False - print(f"Error: {e} from are_equal_under_sympy, {ground_truth_elem}, {given_elem}") + print( + f"Error: {e} from are_equal_under_sympy, {ground_truth_elem}, {given_elem}" + ) if not is_correct: break @@ -373,7 +383,19 @@ def match_answer(response): if dot_idx != -1: response = response[:dot_idx].strip() - for ans_marker in ["be ", "is ", "are ", "=", ": ", "get ", "be\n", "is\n", "are\n", ":\n", "get\n"]: + for ans_marker in [ + "be ", + "is ", + "are ", + "=", + ": ", + "get ", + "be\n", + "is\n", + "are\n", + ":\n", + "get\n", + ]: ans_idx = response.lower().rfind(ans_marker) if ans_idx != -1: is_matched = True @@ -381,7 +403,9 @@ def match_answer(response): if response.endswith("\n"): response = response[:-2] - is_matched = is_matched if any([c.isdigit() for c in response]) else False # answer must have a digit + is_matched = ( + is_matched if any([c.isdigit() for c in response]) else False + ) # answer must have a digit # Grade return is_matched, response @@ -401,7 +425,11 @@ def compute_score(model_output: str, ground_truth: str) -> bool: if "\pi" in extracted_model_output or "\pi" in ground_truth: equivs = [] for pi in [math.pi, 3.14]: - equivs.append(math_equal(extracted_model_output, ground_truth, timeout=True, pi=pi)) + equivs.append( + math_equal( + extracted_model_output, ground_truth, timeout=True, pi=pi + ) + ) is_correct = any(equivs) else: is_correct = math_equal(extracted_model_output, ground_truth, timeout=True) diff --git a/Agent0/executor_train/verl/verl/utils/reward_score/prime_math/grader.py b/Agent0/executor_train/verl/verl/utils/reward_score/prime_math/grader.py index d060584..403e224 100644 --- a/Agent0/executor_train/verl/verl/utils/reward_score/prime_math/grader.py +++ b/Agent0/executor_train/verl/verl/utils/reward_score/prime_math/grader.py @@ -125,7 +125,8 @@ def normalize(answer, pi) -> str: # checking if answer is % or \\% and removing % if isinstance(answer, str) and ( - 
bool(re.match(r"^\d+(\.\d+)?%$", answer)) or bool(re.match(r"^\d+(\.\d+)?\\%$", answer)) + bool(re.match(r"^\d+(\.\d+)?%$", answer)) + or bool(re.match(r"^\d+(\.\d+)?\\%$", answer)) ): return answer.replace("\\%", "").replace("%", "") @@ -188,7 +189,9 @@ def math_equal( prediction = normalize(prediction, pi) reference = normalize(reference, pi) - if isinstance(prediction, str) and len(prediction) > 1000: # handling weird corner-cases + if ( + isinstance(prediction, str) and len(prediction) > 1000 + ): # handling weird corner-cases prediction = prediction[:1000] # 0. string comparison @@ -203,7 +206,11 @@ def math_equal( prediction = is_digit(prediction)[1] reference = is_digit(reference)[1] # number questions - gt_result = [reference / 100, reference, reference * 100] if include_percentage else [reference] + gt_result = ( + [reference / 100, reference, reference * 100] + if include_percentage + else [reference] + ) for item in gt_result: try: if isclose(item, prediction, rel_tol=tolerance): @@ -225,8 +232,14 @@ def math_equal( prediction = format_intervals(prediction) pred_str, ref_str = prediction, reference - if (prediction.startswith("[") and prediction.endswith("]") and not reference.startswith("(")) or ( - prediction.startswith("(") and prediction.endswith(")") and not reference.startswith("[") + if ( + prediction.startswith("[") + and prediction.endswith("]") + and not reference.startswith("(") + ) or ( + prediction.startswith("(") + and prediction.endswith(")") + and not reference.startswith("[") ): pred_str = pred_str.strip("[]()") ref_str = ref_str.strip("[]()") @@ -263,7 +276,9 @@ def math_equal( return bool( all( [ - math_equal(pred_parts[i], ref_parts[i], include_percentage, tolerance) + math_equal( + pred_parts[i], ref_parts[i], include_percentage, tolerance + ) for i in range(len(pred_parts)) ] ) @@ -295,7 +310,11 @@ def math_equal( return True except Exception: pass - elif "\begin{pmatrix}" in reference and prediction.startswith("[") and 
prediction.endswith("]"): + elif ( + "\begin{pmatrix}" in reference + and prediction.startswith("[") + and prediction.endswith("]") + ): if isinstance(eval(prediction), list): try: pred_matrix = eval(prediction) @@ -307,11 +326,15 @@ def math_equal( .rstrip("\end{pmatrix}") ) # noqa: B005 ref_matrix_items = ref_matrix_items.split("\\") - ref_matrix_items = [row.split("&") if "&" in row else row for row in ref_matrix_items] + ref_matrix_items = [ + row.split("&") if "&" in row else row for row in ref_matrix_items + ] if len(pred_matrix) == len(ref_matrix_items) and all( [ math_equal(pred, ref, include_percentage, tolerance) - for ref, pred in zip(ref_matrix_items, pred_matrix, strict=False) + for ref, pred in zip( + ref_matrix_items, pred_matrix, strict=False + ) ] ): return True diff --git a/Agent0/executor_train/verl/verl/utils/reward_score/sandbox_fusion/__init__.py b/Agent0/executor_train/verl/verl/utils/reward_score/sandbox_fusion/__init__.py index cd18498..af4220a 100644 --- a/Agent0/executor_train/verl/verl/utils/reward_score/sandbox_fusion/__init__.py +++ b/Agent0/executor_train/verl/verl/utils/reward_score/sandbox_fusion/__init__.py @@ -26,7 +26,13 @@ def compute_score( - sandbox_fusion_url, concurrent_semaphore, memory_limit_mb, completion, test_cases, continuous=False, timeout=10 + sandbox_fusion_url, + concurrent_semaphore, + memory_limit_mb, + completion, + test_cases, + continuous=False, + timeout=10, ): """ Computes the code score using the remote sandbox API. 
@@ -70,7 +76,9 @@ def compute_score( if not test_cases or "inputs" not in test_cases or "outputs" not in test_cases: logger.error("Invalid test_cases structure.") - return 0.0, [{"error": "Invalid test_cases structure (missing inputs/outputs)"}] + return 0.0, [ + {"error": "Invalid test_cases structure (missing inputs/outputs)"} + ] # Check all test cases # Note: The return value of check_correctness might need adaptation here @@ -111,7 +119,13 @@ def compute_score( traceback.print_exc() score = 0.0 # Try to return partial metadata if available, otherwise return error info - final_metadata = metadata_list if "metadata_list" in locals() else [{"error": f"Unhandled exception: {e}"}] + final_metadata = ( + metadata_list + if "metadata_list" in locals() + else [{"error": f"Unhandled exception: {e}"}] + ) # Ensure float and list are returned - return float(score), final_metadata if isinstance(final_metadata, list) else [final_metadata] + return float(score), ( + final_metadata if isinstance(final_metadata, list) else [final_metadata] + ) diff --git a/Agent0/executor_train/verl/verl/utils/reward_score/sandbox_fusion/utils.py b/Agent0/executor_train/verl/verl/utils/reward_score/sandbox_fusion/utils.py index d2154ca..76c22c6 100644 --- a/Agent0/executor_train/verl/verl/utils/reward_score/sandbox_fusion/utils.py +++ b/Agent0/executor_train/verl/verl/utils/reward_score/sandbox_fusion/utils.py @@ -139,8 +139,12 @@ def call_sandbox_api( # Calculate increasing delay (e.g., 1s, 2s, 4s, ...) or (1s, 2s, 3s, ...) # Simple linear increase: delay = INITIAL_RETRY_DELAY * (attempt + 1) # Exponential backoff: delay = INITIAL_RETRY_DELAY * (2 ** attempt) - delay = INITIAL_RETRY_DELAY * (attempt + 1) # Using linear increase for simplicity - logger.info(f"{log_prefix}Retrying after {delay} seconds...") # <-- Use internal log_prefix + delay = INITIAL_RETRY_DELAY * ( + attempt + 1 + ) # Using linear increase for simplicity + logger.info( + f"{log_prefix}Retrying after {delay} seconds..." 
+ ) # <-- Use internal log_prefix time.sleep(delay) continue # Go to the next retry attempt @@ -154,21 +158,31 @@ def call_sandbox_api( return response.json(), None except requests.exceptions.RequestException as e: - last_error = f"{log_prefix}API Request Error: {e}" # <-- Use internal log_prefix + last_error = ( + f"{log_prefix}API Request Error: {e}" # <-- Use internal log_prefix + ) break # Exit retry loop on non-504 request errors except json.JSONDecodeError as e: raw_response_text = response.text if "response" in locals() else "N/A" last_error = f"{log_prefix}API Response JSON Decode Error: {e}" # <-- Use internal log_prefix break # Exit retry loop on JSON decode errors except Exception as e: - last_error = f"{log_prefix}Unexpected Error: {e}" # <-- Use internal log_prefix + last_error = ( + f"{log_prefix}Unexpected Error: {e}" # <-- Use internal log_prefix + ) break # Exit retry loop on other unexpected errors # If loop finishes without returning success, return the last recorded error - logger.error(f"{log_prefix}Sandbox API call failed. Last error: {last_error}") # <-- Use internal log_prefix + logger.error( + f"{log_prefix}Sandbox API call failed. Last error: {last_error}" + ) # <-- Use internal log_prefix # Return the error message without the prefix, as the caller doesn't need the internal ID # Ensure API call failure returns error message, leading to -1 in check_correctness - return None, last_error.replace(log_prefix, "API Call Failed: ") if last_error else "API Call Failed after retries" + return None, ( + last_error.replace(log_prefix, "API Call Failed: ") + if last_error + else "API Call Failed after retries" + ) def _process_single_case( @@ -344,7 +358,9 @@ def _execute_user_function(): result_status = -1 # API request itself failed (includes timeout after retries) logger.error(f"Case {case_index}: API error occurred: {error_msg}") # Log code and input only on error for brevity - generation_to_log = generation[:200] + "..." 
if len(generation) > 200 else generation + generation_to_log = ( + generation[:200] + "..." if len(generation) > 200 else generation + ) logger.error(f"Case {case_index}: code: {generation_to_log}") logger.error(f"Case {case_index}: input: {str(stdin_data)}") elif api_response: @@ -384,7 +400,10 @@ def _execute_user_function(): # Compile failed or timed out is_compile_error = compile_result and ( metadata["compile_status"] in ["Error", "TimeLimitExceeded"] - or (metadata["compile_status"] == "Finished" and compile_result.get("return_code") != 0) + or ( + metadata["compile_status"] == "Finished" + and compile_result.get("return_code") != 0 + ) ) if is_compile_error: # Differentiate between compile_error and compile_timeout based on specific status @@ -399,7 +418,10 @@ def _execute_user_function(): is_runtime_error = ( metadata["run_status"] == "TimeLimitExceeded" or metadata["run_status"] == "Error" - or (metadata["run_status"] == "Finished" and run_result.get("return_code") != 0) + or ( + metadata["run_status"] == "Finished" + and run_result.get("return_code") != 0 + ) ) if is_runtime_error: if metadata["run_status"] == "TimeLimitExceeded": @@ -410,18 +432,24 @@ def _execute_user_function(): result_status = -2 else: # Other Failed status with run_result, classify as unknown failure - logger.warning(f"Unknown run_status '{metadata['run_status']}' or state within Failed API status.") + logger.warning( + f"Unknown run_status '{metadata['run_status']}' or state within Failed API status." + ) metadata["status"] = "unknown_failure" result_status = -1 # Default to -1 else: # Status is Failed but neither a clear compile error nor run_result exists - logger.warning("API status Failed but cannot determine specific error type (compile/run).") + logger.warning( + "API status Failed but cannot determine specific error type (compile/run)." 
+ ) metadata["status"] = "unknown_failure_state" result_status = -1 # Default to -1 elif api_status == "Success": # Run completed successfully, now check the answer if run_result and metadata["run_status"] == "Finished": - actual_output = metadata["stdout"] if metadata["stdout"] is not None else "" + actual_output = ( + metadata["stdout"] if metadata["stdout"] is not None else "" + ) # Note: Output might contain trailing newlines, need normalization if str(actual_output).rstrip("\n") == str(expected_output).rstrip("\n"): result_status = True @@ -441,7 +469,9 @@ def _execute_user_function(): else: # api_response is None and no error_msg (Should not happen with current call_sandbox_api logic) metadata["status"] = "unknown_api_state" result_status = -1 - logger.error(f"Case {case_index}: Unknown API state (no response and no error message).") + logger.error( + f"Case {case_index}: Unknown API state (no response and no error message)." + ) return result_status, metadata @@ -491,14 +521,21 @@ def check_correctness( return [], [] if len(inputs) != len(expected_outputs): - logger.warning(f"Mismatch between number of inputs ({len(inputs)}) and outputs ({len(expected_outputs)}).") + logger.warning( + f"Mismatch between number of inputs ({len(inputs)}) and outputs ({len(expected_outputs)})." 
+ ) # Return error based on the number of inputs provided - return [-1] * num_cases, [{"error": "Input/output count mismatch", "case_index": i} for i in range(num_cases)] + return [-1] * num_cases, [ + {"error": "Input/output count mismatch", "case_index": i} + for i in range(num_cases) + ] first_compile_error_index = -1 # max_workers is limited by sandbox_fusion_max_concurrent from concurrent_semaphore - with concurrent.futures.ThreadPoolExecutor(max_workers=max(32, os.cpu_count() * 5)) as executor: + with concurrent.futures.ThreadPoolExecutor( + max_workers=max(32, os.cpu_count() * 5) + ) as executor: # Submit all tasks, passing the concurrent_semaphore to _process_single_case future_to_index = { executor.submit( @@ -527,7 +564,10 @@ def check_correctness( # Check for compile error (-4) if result_status == -4: - if first_compile_error_index == -1 or index < first_compile_error_index: + if ( + first_compile_error_index == -1 + or index < first_compile_error_index + ): first_compile_error_index = index # Optimization: could potentially cancel futures for index > first_compile_error_index # However, cancellation is not guaranteed. Post-processing is safer. 
@@ -554,7 +594,9 @@ def check_correctness( if results[i] != -4: # Avoid overwriting if it somehow already got -4 results[i] = -4 # Update or create metadata for skipped cases due to compile error - if metadata_list[i] is None: # If future failed before returning metadata + if ( + metadata_list[i] is None + ): # If future failed before returning metadata metadata_list[i] = { "case_index": i, "input": str(inputs[i]), diff --git a/Agent0/executor_train/verl/verl/utils/reward_score/search_r1_like_qa_em.py b/Agent0/executor_train/verl/verl/utils/reward_score/search_r1_like_qa_em.py index 56782fc..40a36e7 100644 --- a/Agent0/executor_train/verl/verl/utils/reward_score/search_r1_like_qa_em.py +++ b/Agent0/executor_train/verl/verl/utils/reward_score/search_r1_like_qa_em.py @@ -93,7 +93,9 @@ def count_answer_tags(text): return opening_tags, closing_tags -def compute_score(solution_str, ground_truth, method="strict", format_score=0.0, score=1.0): +def compute_score( + solution_str, ground_truth, method="strict", format_score=0.0, score=1.0 +): """The scoring function for exact match (EM). Args: @@ -128,7 +130,9 @@ def compute_score(solution_str, ground_truth, method="strict", format_score=0.0, return format_score -def compute_score_subem(solution_str, ground_truth, method="strict", format_score=0.0, score=1.0): +def compute_score_subem( + solution_str, ground_truth, method="strict", format_score=0.0, score=1.0 +): """The scoring function for substring exact match (EM). 
Args: diff --git a/Agent0/executor_train/verl/verl/utils/rollout_trace.py b/Agent0/executor_train/verl/verl/utils/rollout_trace.py index 114006d..4bee639 100644 --- a/Agent0/executor_train/verl/verl/utils/rollout_trace.py +++ b/Agent0/executor_train/verl/verl/utils/rollout_trace.py @@ -37,7 +37,13 @@ def get_instance(cls) -> "RolloutTraceConfig": return cls._instance @classmethod - def init(cls, project_name: str, experiment_name: str, backend: str, token2text: bool = False): + def init( + cls, + project_name: str, + experiment_name: str, + backend: str, + token2text: bool = False, + ): config = cls.get_instance() config.backend = backend config.token2text = token2text @@ -123,15 +129,23 @@ async def async_wrapper(self, *args, **kwargs): del inputs["self"] async def add_token2text(self, result): - if hasattr(result, "prompt_ids") and hasattr(self, "tokenizer") and hasattr(self.tokenizer, "decode"): + if ( + hasattr(result, "prompt_ids") + and hasattr(self, "tokenizer") + and hasattr(self.tokenizer, "decode") + ): _result = [result] loop = asyncio.get_running_loop() if hasattr(result, "prompt_ids"): - prompt_text = await loop.run_in_executor(None, self.tokenizer.decode, result.prompt_ids) + prompt_text = await loop.run_in_executor( + None, self.tokenizer.decode, result.prompt_ids + ) _result.append(prompt_text) if hasattr(result, "response_ids"): - response_text = await loop.run_in_executor(None, self.tokenizer.decode, result.response_ids) + response_text = await loop.run_in_executor( + None, self.tokenizer.decode, result.response_ids + ) _result.append(response_text) return _result return result @@ -141,7 +155,9 @@ async def add_token2text(self, result): from weave.trace.context import call_context cur_attributes = {**call_context.call_attributes.get()} - call = tracer.create_call(op=func.__qualname__, inputs=inputs, attributes=cur_attributes) + call = tracer.create_call( + op=func.__qualname__, inputs=inputs, attributes=cur_attributes + ) try: result = await 
func(self, *args, **kwargs) @@ -177,7 +193,9 @@ def wrapper(self, *args, **kwargs): from weave.trace.context import call_context cur_attributes = {**call_context.call_attributes.get()} - call = tracer.create_call(op=func.__qualname__, inputs=inputs, attributes=cur_attributes) + call = tracer.create_call( + op=func.__qualname__, inputs=inputs, attributes=cur_attributes + ) try: result = func(self, *args, **kwargs) tracer.finish_call(call, output=result) diff --git a/Agent0/executor_train/verl/verl/utils/seqlen_balancing.py b/Agent0/executor_train/verl/verl/utils/seqlen_balancing.py index 4938e8f..2e8e493 100644 --- a/Agent0/executor_train/verl/verl/utils/seqlen_balancing.py +++ b/Agent0/executor_train/verl/verl/utils/seqlen_balancing.py @@ -97,7 +97,9 @@ def __repr__(self) -> str: sorted_seqlen_list = sorted([(seqlen, i) for i, seqlen in enumerate(seqlen_list)]) states_pq = [] if equal_size: - assert len(seqlen_list) % k_partitions == 0, f"{len(seqlen_list)} % {k_partitions} != 0" + assert ( + len(seqlen_list) % k_partitions == 0 + ), f"{len(seqlen_list)} % {k_partitions} != 0" for offset in range(0, len(sorted_seqlen_list), k_partitions): items = [] for i in range(k_partitions): @@ -119,9 +121,9 @@ def __repr__(self) -> str: partitions = final_state.get_partitions() if equal_size: for i, partition in enumerate(partitions): - assert len(partition) * k_partitions == len(seqlen_list), ( - f"{len(partition)} * {k_partitions} != {len(seqlen_list)}" - ) + assert len(partition) * k_partitions == len( + seqlen_list + ), f"{len(partition)} * {k_partitions} != {len(seqlen_list)}" return partitions @@ -139,13 +141,15 @@ def greedy_partition(seqlen_list: list[int], k_partitions: int, equal_size: bool partition_sums[min_idx] += seqlen if equal_size: for i, partition in enumerate(partitions): - assert len(partition) * k_partitions == len(seqlen_list), ( - f"{len(partition)} * {k_partitions} != {len(seqlen_list)}" - ) + assert len(partition) * k_partitions == len( + seqlen_list + 
), f"{len(partition)} * {k_partitions} != {len(seqlen_list)}" return partitions -def get_seqlen_balanced_partitions(seqlen_list: list[int], k_partitions: int, equal_size: bool): +def get_seqlen_balanced_partitions( + seqlen_list: list[int], k_partitions: int, equal_size: bool +): """ Calculates partitions of indices from seqlen_list such that the sum of sequence lengths in each partition is balanced. Uses the Karmarkar-Karp differencing method. @@ -171,7 +175,9 @@ def get_seqlen_balanced_partitions(seqlen_list: list[int], k_partitions: int, eq AssertionError: If equal_size is True and len(seqlen_list) is not divisible by k_partitions. AssertionError: If any resulting partition is empty. """ - assert len(seqlen_list) >= k_partitions, f"number of items:[{len(seqlen_list)}] < k_partitions:[{k_partitions}]" + assert ( + len(seqlen_list) >= k_partitions + ), f"number of items:[{len(seqlen_list)}] < k_partitions:[{k_partitions}]" def _check_and_sort_partitions(partitions): assert len(partitions) == k_partitions, f"{len(partitions)} != {k_partitions}" @@ -185,7 +191,9 @@ def _check_and_sort_partitions(partitions): assert seen_idx == set(range(len(seqlen_list))) return sorted_partitions - partitions = karmarkar_karp(seqlen_list=seqlen_list, k_partitions=k_partitions, equal_size=equal_size) + partitions = karmarkar_karp( + seqlen_list=seqlen_list, k_partitions=k_partitions, equal_size=equal_size + ) return _check_and_sort_partitions(partitions) @@ -270,13 +278,15 @@ def rearrange_micro_batches( """ # this is per local micro_bsz max_seq_len = batch["attention_mask"].shape[-1] - assert max_token_len >= max_seq_len, ( - f"max_token_len must be greater than the sequence length. Got {max_token_len=} and {max_seq_len=}" - ) + assert ( + max_token_len >= max_seq_len + ), f"max_token_len must be greater than the sequence length. 
Got {max_token_len=} and {max_seq_len=}" seq_len_effective: torch.Tensor = batch["attention_mask"].sum(dim=1) total_seqlen = seq_len_effective.sum().item() # NOTE: num_microbatches <= batch_size, so take the min of this two. - num_micro_batches = min(len(seq_len_effective), ceildiv(total_seqlen, max_token_len)) + num_micro_batches = min( + len(seq_len_effective), ceildiv(total_seqlen, max_token_len) + ) if min_num_micro_batch is not None: # used to support pp num_micro_batches = max(min_num_micro_batch, num_micro_batches) @@ -290,7 +300,9 @@ def rearrange_micro_batches( seq_len_effective = seq_len_effective.tolist() assert num_micro_batches <= len(seq_len_effective) - micro_bsz_idx = get_seqlen_balanced_partitions(seq_len_effective, num_micro_batches, equal_size=False) + micro_bsz_idx = get_seqlen_balanced_partitions( + seq_len_effective, num_micro_batches, equal_size=False + ) micro_batches = [] diff --git a/Agent0/executor_train/verl/verl/utils/tokenizer.py b/Agent0/executor_train/verl/verl/utils/tokenizer.py index 668ea3e..1391631 100644 --- a/Agent0/executor_train/verl/verl/utils/tokenizer.py +++ b/Agent0/executor_train/verl/verl/utils/tokenizer.py @@ -27,10 +27,16 @@ def set_pad_token_id(tokenizer): """ if tokenizer.pad_token_id is None: tokenizer.pad_token_id = tokenizer.eos_token_id - warnings.warn(f"tokenizer.pad_token_id is None. Now set to {tokenizer.eos_token_id}", stacklevel=1) + warnings.warn( + f"tokenizer.pad_token_id is None. Now set to {tokenizer.eos_token_id}", + stacklevel=1, + ) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token - warnings.warn(f"tokenizer.pad_token is None. Now set to {tokenizer.eos_token}", stacklevel=1) + warnings.warn( + f"tokenizer.pad_token is None. 
Now set to {tokenizer.eos_token}", + stacklevel=1, + ) def hf_tokenizer(name_or_path, correct_pad_token=True, correct_gemma2=True, **kwargs): @@ -49,11 +55,16 @@ def hf_tokenizer(name_or_path, correct_pad_token=True, correct_gemma2=True, **kw """ from transformers import AutoTokenizer - if correct_gemma2 and isinstance(name_or_path, str) and "gemma-2-2b-it" in name_or_path: + if ( + correct_gemma2 + and isinstance(name_or_path, str) + and "gemma-2-2b-it" in name_or_path + ): # the EOS token in gemma2 is ambiguious, which may worsen RL performance. # https://huggingface.co/google/gemma-2-2b-it/commit/17a01657f5c87135bcdd0ec7abb4b2dece04408a warnings.warn( - "Found gemma-2-2b-it tokenizer. Set eos_token and eos_token_id to and 107.", stacklevel=1 + "Found gemma-2-2b-it tokenizer. Set eos_token and eos_token_id to and 107.", + stacklevel=1, ) kwargs["eos_token"] = "" kwargs["eos_token_id"] = 107 @@ -80,7 +91,10 @@ def hf_processor(name_or_path, **kwargs): processor = None # TODO(haibin.lin): try-catch should be removed after adding transformer version req to setup.py to avoid # silent failure - warnings.warn(f"Failed to create processor: {e}. This may affect multimodal processing", stacklevel=1) + warnings.warn( + f"Failed to create processor: {e}. 
This may affect multimodal processing", + stacklevel=1, + ) # Avoid load tokenizer, see: # https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/auto/processing_auto.py#L344 if processor is not None and "Processor" not in processor.__class__.__name__: diff --git a/Agent0/executor_train/verl/verl/utils/torch_functional.py b/Agent0/executor_train/verl/verl/utils/torch_functional.py index df91ad7..19adbf4 100644 --- a/Agent0/executor_train/verl/verl/utils/torch_functional.py +++ b/Agent0/executor_train/verl/verl/utils/torch_functional.py @@ -83,7 +83,9 @@ def logprobs_from_logits(logits, labels, inplace_backward=True): last_dim = logits.shape[-1] logits = logits.reshape(-1, last_dim) labels = labels.reshape(-1) - output = logprobs_from_logits_flash_attn(logits, labels, inplace_backward=inplace_backward) + output = logprobs_from_logits_flash_attn( + logits, labels, inplace_backward=inplace_backward + ) output = output.view(*batch_dim) elif NPU_CROSS_ENTROPY_LOSS_AVAILABLE: output = logprobs_from_logits_torch_npu(logits, labels) @@ -94,16 +96,18 @@ def logprobs_from_logits(logits, labels, inplace_backward=True): def logprobs_from_logits_flash_attn(logits, labels, inplace_backward=True): output = cross_entropy_loss(logits, labels, inplace_backward=inplace_backward) - assert isinstance(output, tuple), ( - "please make sure flash-attn>=2.4.3 where cross_entropy_loss returns Tuple[losses, z_losses]." - ) + assert isinstance( + output, tuple + ), "please make sure flash-attn>=2.4.3 where cross_entropy_loss returns Tuple[losses, z_losses]." 
return -output[0] def logprobs_from_logits_torch_npu(logits, labels): batch_dim = logits.shape[:-1] logits = logits.reshape(-1, logits.shape[-1]) - loss, _, _, _ = torch_npu.npu_cross_entropy_loss(logits, labels.reshape(-1), reduction="none") + loss, _, _, _ = torch_npu.npu_cross_entropy_loss( + logits, labels.reshape(-1), reduction="none" + ) return -loss.view(*batch_dim) @@ -118,16 +122,26 @@ def logprobs_from_logits_v2(logits: torch.FloatTensor, labels): A memory efficient implementation of logprobs_from_logits """ if logits.dtype in [torch.float32, torch.float64]: - logits_labels = torch.gather(logits, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1) + logits_labels = torch.gather( + logits, dim=-1, index=labels.unsqueeze(-1) + ).squeeze(-1) # loop to reduce peak mem consumption - logsumexp_values = torch.stack([torch.logsumexp(logit, dim=-1) for logit in logits]) - logprobs_labels = logits_labels - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x) + logsumexp_values = torch.stack( + [torch.logsumexp(logit, dim=-1) for logit in logits] + ) + logprobs_labels = ( + logits_labels - logsumexp_values + ) # log_softmax(x_i) = x_i - logsumexp(x) else: # logsumexp approach is unstable with bfloat16, fall back to slightly less efficent approach logprobs_labels = [] - for row_logits, row_labels in zip(logits, labels, strict=True): # loop to reduce peak mem consumption + for row_logits, row_labels in zip( + logits, labels, strict=True + ): # loop to reduce peak mem consumption row_logprobs = F.log_softmax(row_logits, dim=-1) - row_logprobs_labels = row_logprobs.gather(dim=-1, index=row_labels.unsqueeze(-1)).squeeze(-1) + row_logprobs_labels = row_logprobs.gather( + dim=-1, index=row_labels.unsqueeze(-1) + ).squeeze(-1) logprobs_labels.append(row_logprobs_labels) logprobs_labels = torch.stack(logprobs_labels) return logprobs_labels @@ -155,7 +169,9 @@ def entropy_from_logits_with_chunking(logits: torch.Tensor, chunk_size: int = 20 for i in range(0, logits.shape[0], 
chunk_size): logits_chunk = logits[i : i + chunk_size].float() pd_chunk = torch.nn.functional.softmax(logits_chunk, dim=-1) - entropy_chunk = torch.logsumexp(logits_chunk, dim=-1) - torch.sum(pd_chunk * logits_chunk, dim=-1) + entropy_chunk = torch.logsumexp(logits_chunk, dim=-1) - torch.sum( + pd_chunk * logits_chunk, dim=-1 + ) entropy[i : i + chunk_size] = entropy_chunk return entropy @@ -197,7 +213,9 @@ def masked_var(values, mask, unbiased=True): # note that if mask_sum == 1, then there is a division by zero issue # to avoid it you just need to use a larger minibatch_size if mask_sum == 1: - raise ValueError("The sum of the mask is one, which can cause a division by zero.") + raise ValueError( + "The sum of the mask is one, which can cause a division by zero." + ) bessel_correction = mask_sum / (mask_sum - 1) variance = variance * bessel_correction return variance @@ -223,7 +241,9 @@ def masked_whiten(values, mask, shift_mean=True): return whitened -def get_response_mask(response_id: torch.Tensor, eos_token: int | list[int] = 2, dtype=torch.int64): +def get_response_mask( + response_id: torch.Tensor, eos_token: int | list[int] = 2, dtype=torch.int64 +): """ end of sentence token can be int or list: 1 or [1, 2] e.g. 
@@ -242,7 +262,9 @@ def get_response_mask(response_id: torch.Tensor, eos_token: int | list[int] = 2, [1, 1, 1, 0, 0, 0, 0], [1, 1, 1, 1, 1, 0, 0]]) """ - eos_mask = torch.isin(response_id, torch.tensor(eos_token, device=response_id.device)).int() + eos_mask = torch.isin( + response_id, torch.tensor(eos_token, device=response_id.device) + ).int() return (eos_mask.cumsum(dim=1) - eos_mask).eq(0).to(dtype) @@ -263,7 +285,9 @@ def broadcast_dict_tensor(tensors: dict[str, torch.Tensor] | TensorDict, src, gr torch.distributed.broadcast(tensors[key], src=src, group=group, async_op=False) -def allgather_dict_tensors(tensors: dict[str, torch.Tensor] | TensorDict, size, group, dim=0): +def allgather_dict_tensors( + tensors: dict[str, torch.Tensor] | TensorDict, size, group, dim=0 +): """ TODO: optimize this. - We can use async ops @@ -298,9 +322,9 @@ def allgather_dict_tensors(tensors: dict[str, torch.Tensor] | TensorDict, size, def split_dict_tensor_into_batches(tensors: TensorDict, batch_size) -> list[TensorDict]: - assert tensors.batch_size[0] % batch_size == 0, ( - f"input data batch size: {tensors.batch_size[0]}, split batch size: {batch_size}" - ) + assert ( + tensors.batch_size[0] % batch_size == 0 + ), f"input data batch size: {tensors.batch_size[0]}, split batch size: {batch_size}" return tensors.split(batch_size) @@ -309,8 +333,15 @@ def pad_2d_list_to_length(response, pad_token_id, max_length=None): pad a 2D list (e.g. responses, logprobs) to a 2D tensor. 
""" response_length = max(len(sub_list) for sub_list in response) - target_length = max_length if max_length is not None and max_length > response_length else response_length - padded_response = [tuple(sub_list) + (pad_token_id,) * (target_length - len(sub_list)) for sub_list in response] + target_length = ( + max_length + if max_length is not None and max_length > response_length + else response_length + ) + padded_response = [ + tuple(sub_list) + (pad_token_id,) * (target_length - len(sub_list)) + for sub_list in response + ] tensor = torch.tensor(padded_response) return tensor @@ -324,7 +355,11 @@ def pad_sequence_to_length(tensors, max_seq_len, pad_token_id, left_pad=False): if tensors.shape[-1] >= max_seq_len: return tensors # (0, max_seq_len - tensors.shape[-1]) means right pad to max_seq_length and no left pad - pad_tuple = (max_seq_len - tensors.shape[-1], 0) if left_pad else (0, max_seq_len - tensors.shape[-1]) + pad_tuple = ( + (max_seq_len - tensors.shape[-1], 0) + if left_pad + else (0, max_seq_len - tensors.shape[-1]) + ) return F.pad(tensors, pad_tuple, "constant", pad_token_id) @@ -355,7 +390,10 @@ def postprocess_data( sequence_length = input_ids.shape[-1] if sequence_length < max_length: input_ids = pad_sequence_to_length( - input_ids, max_seq_len=max_length, pad_token_id=pad_token_id, left_pad=left_pad + input_ids, + max_seq_len=max_length, + pad_token_id=pad_token_id, + left_pad=left_pad, ) attention_mask = pad_sequence_to_length( attention_mask, max_seq_len=max_length, pad_token_id=0, left_pad=left_pad @@ -371,10 +409,16 @@ def postprocess_data( elif truncation == "middle": left_half = max_length // 2 right_half = max_length - left_half - input_ids = torch.cat([input_ids[:, :left_half], input_ids[:, -right_half:]], dim=-1) - attention_mask = torch.cat([attention_mask[:, :left_half], attention_mask[:, -right_half:]], dim=-1) + input_ids = torch.cat( + [input_ids[:, :left_half], input_ids[:, -right_half:]], dim=-1 + ) + attention_mask = torch.cat( 
+ [attention_mask[:, :left_half], attention_mask[:, -right_half:]], dim=-1 + ) elif truncation == "error": - raise NotImplementedError(f"{sequence_length=} is larger than {max_length=}") + raise NotImplementedError( + f"{sequence_length=} is larger than {max_length=}" + ) else: raise NotImplementedError(f"Unknown truncation method {truncation}") @@ -382,7 +426,12 @@ def postprocess_data( def tokenize_and_postprocess_data( - prompt: str, tokenizer: PreTrainedTokenizer, max_length: int, pad_token_id: int, left_pad=True, truncation="error" + prompt: str, + tokenizer: PreTrainedTokenizer, + max_length: int, + pad_token_id: int, + left_pad=True, + truncation="error", ): """Tokenize text and process outputs to consistent tensor shapes. @@ -401,7 +450,9 @@ def tokenize_and_postprocess_data( input_ids = input_data["input_ids"] attention_mask = input_data["attention_mask"] - return postprocess_data(input_ids, attention_mask, max_length, pad_token_id, left_pad, truncation) + return postprocess_data( + input_ids, attention_mask, max_length, pad_token_id, left_pad, truncation + ) def remove_pad_token(input_ids: torch.Tensor, attention_mask: torch.Tensor): @@ -435,7 +486,9 @@ def log_probs_from_logits_response(input_ids, logits, response_length): return response_log_prob -def log_probs_from_logits_response_rmpad(input_ids, attention_mask, logits_rmpad, response_length): +def log_probs_from_logits_response_rmpad( + input_ids, attention_mask, logits_rmpad, response_length +): """Compute the log_probs from logits with rmpad logits and pad input. Note that logits_rmpad = model(input_ids_rmpad). For each sentences, there is a shift between logits and input_ids. 
@@ -451,18 +504,29 @@ def log_probs_from_logits_response_rmpad(input_ids, attention_mask, logits_rmpad from flash_attn.bert_padding import pad_input, unpad_input batch_size, seqlen = input_ids.shape - input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask=attention_mask) + input_ids_rmpad, indices, *_ = unpad_input( + input_ids.unsqueeze(-1), attention_mask=attention_mask + ) input_ids_rmpad = input_ids_rmpad.squeeze(-1) input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=0) - full_log_probs_rmpad = logprobs_from_logits(logits=logits_rmpad, labels=input_ids_rmpad_rolled) # (total_nnz,) + full_log_probs_rmpad = logprobs_from_logits( + logits=logits_rmpad, labels=input_ids_rmpad_rolled + ) # (total_nnz,) full_output = pad_input( - hidden_states=full_log_probs_rmpad.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen + hidden_states=full_log_probs_rmpad.unsqueeze(-1), + indices=indices, + batch=batch_size, + seqlen=seqlen, ) - output = full_output.squeeze(-1)[:, -response_length - 1 : -1] # [batch_size, response_length] + output = full_output.squeeze(-1)[ + :, -response_length - 1 : -1 + ] # [batch_size, response_length] return output -def log_probs_from_logits_all_rmpad(input_ids_rmpad, logits_rmpad, indices, batch_size, seqlen, response_length): +def log_probs_from_logits_all_rmpad( + input_ids_rmpad, logits_rmpad, indices, batch_size, seqlen, response_length +): """Compute the log_probs from logits with rmpad input_ids and logits. Note that logits_rmpad = model(input_ids_rmpad). For each sentences, there is a shift between logits and input_ids. 
@@ -479,14 +543,23 @@ def log_probs_from_logits_all_rmpad(input_ids_rmpad, logits_rmpad, indices, batc """ from flash_attn.bert_padding import pad_input - input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # transpose back to [total_nnz, 1] + input_ids_rmpad = input_ids_rmpad.transpose( + 0, 1 + ) # transpose back to [total_nnz, 1] input_ids_rmpad = input_ids_rmpad.squeeze(-1) input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=0) - full_log_probs_rmpad = logprobs_from_logits(logits=logits_rmpad, labels=input_ids_rmpad_rolled) # (total_nnz,) + full_log_probs_rmpad = logprobs_from_logits( + logits=logits_rmpad, labels=input_ids_rmpad_rolled + ) # (total_nnz,) full_output = pad_input( - hidden_states=full_log_probs_rmpad.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen + hidden_states=full_log_probs_rmpad.unsqueeze(-1), + indices=indices, + batch=batch_size, + seqlen=seqlen, ) - output = full_output.squeeze(-1)[:, -response_length - 1 : -1] # [batch_size, response_length] + output = full_output.squeeze(-1)[ + :, -response_length - 1 : -1 + ] # [batch_size, response_length] return output @@ -542,8 +615,12 @@ def get_cosine_schedule_with_warmup( def lr_lambda(current_step): if current_step < num_warmup_steps: - return min_lr_ratio + (1.0 - min_lr_ratio) * (float(current_step) / float(max(1, num_warmup_steps))) - progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) + return min_lr_ratio + (1.0 - min_lr_ratio) * ( + float(current_step) / float(max(1, num_warmup_steps)) + ) + progress = float(current_step - num_warmup_steps) / float( + max(1, num_training_steps - num_warmup_steps) + ) x = math.cos(math.pi * float(num_cycles) * 2.0 * progress) return max(min_lr_ratio, x * coef + intercept) @@ -588,18 +665,22 @@ def prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds): if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - 
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( - inputs_embeds.device - ) + expanded_attn_mask = _expand_mask( + attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ).to(inputs_embeds.device) combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + expanded_attn_mask + if combined_attention_mask is None + else expanded_attn_mask + combined_attention_mask ) return combined_attention_mask # Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device): +def _make_causal_mask( + input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device +): """ Make causal mask used for bi-directional self-attention. """ @@ -623,7 +704,9 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] inverted_mask = 1.0 - expanded_mask - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + return inverted_mask.masked_fill( + inverted_mask.to(torch.bool), torch.finfo(dtype).min + ) def get_unpad_data(attention_mask): @@ -685,8 +768,13 @@ def lr_lambda(current_step): if current_step < num_warmup_steps + num_stable_steps: return 1.0 if current_step < num_training_steps: - progress = float(current_step - num_warmup_steps - num_stable_steps) / float(max(1, num_decay_steps)) - value = max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) + progress = float( + current_step - num_warmup_steps - num_stable_steps + ) / float(max(1, num_decay_steps)) + value = max( + 0.0, + 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)), + ) return (1.0 - min_lr_ratio) * value + min_lr_ratio return min_lr_ratio @@ -701,12 +789,18 @@ def check_device_is_available(): This context manager checks if CUDA is available and raises an error if it is not. 
""" if not get_torch_device().is_available(): - raise RuntimeError("Device {} must be initialized before importing this module.".format(get_device_name())) + raise RuntimeError( + "Device {} must be initialized before importing this module.".format( + get_device_name() + ) + ) yield -def distributed_mean_max_min_std(local_tensor, compute_max=True, compute_min=True, compute_std=True): +def distributed_mean_max_min_std( + local_tensor, compute_max=True, compute_min=True, compute_std=True +): """Compute distributed statistics across all processes. Args: diff --git a/Agent0/executor_train/verl/verl/utils/tracking.py b/Agent0/executor_train/verl/verl/utils/tracking.py index 07f45a3..867302c 100644 --- a/Agent0/executor_train/verl/verl/utils/tracking.py +++ b/Agent0/executor_train/verl/verl/utils/tracking.py @@ -34,16 +34,34 @@ class Tracking: logger: Dictionary of initialized logger instances for each backend. """ - supported_backend = ["wandb", "mlflow", "swanlab", "vemlp_wandb", "tensorboard", "console", "clearml"] - - def __init__(self, project_name, experiment_name, default_backend: str | list[str] = "console", config=None): + supported_backend = [ + "wandb", + "mlflow", + "swanlab", + "vemlp_wandb", + "tensorboard", + "console", + "clearml", + ] + + def __init__( + self, + project_name, + experiment_name, + default_backend: str | list[str] = "console", + config=None, + ): if isinstance(default_backend, str): default_backend = [default_backend] for backend in default_backend: if backend == "tracking": import warnings - warnings.warn("`tracking` logger is deprecated. use `wandb` instead.", DeprecationWarning, stacklevel=2) + warnings.warn( + "`tracking` logger is deprecated. 
use `wandb` instead.", + DeprecationWarning, + stacklevel=2, + ) else: assert backend in self.supported_backend, f"{backend} is not supported" @@ -55,7 +73,12 @@ def __init__(self, project_name, experiment_name, default_backend: str | list[st settings = None if config and config["trainer"].get("wandb_proxy", None): settings = wandb.Settings(https_proxy=config["trainer"]["wandb_proxy"]) - wandb.init(project=project_name, name=experiment_name, config=config, settings=settings) + wandb.init( + project=project_name, + name=experiment_name, + config=config, + settings=settings, + ) self.logger["wandb"] = wandb if "mlflow" in default_backend: @@ -70,7 +93,9 @@ def __init__(self, project_name, experiment_name, default_backend: str | list[st # Project_name is actually experiment_name in MLFlow # If experiment does not exist, will create a new experiment experiment = mlflow.set_experiment(project_name) - mlflow.start_run(experiment_id=experiment.experiment_id, run_name=experiment_name) + mlflow.start_run( + experiment_id=experiment.experiment_id, run_name=experiment_name + ) mlflow.log_params(_compute_mlflow_params_from_objects(config)) self.logger["mlflow"] = _MlflowLoggingAdapter() @@ -83,10 +108,14 @@ def __init__(self, project_name, experiment_name, default_backend: str | list[st SWANLAB_LOG_DIR = os.environ.get("SWANLAB_LOG_DIR", "swanlog") SWANLAB_MODE = os.environ.get("SWANLAB_MODE", "cloud") if SWANLAB_API_KEY: - swanlab.login(SWANLAB_API_KEY) # NOTE: previous login information will be overwritten + swanlab.login( + SWANLAB_API_KEY + ) # NOTE: previous login information will be overwritten if config is None: - config = {} # make sure config is not None, otherwise **config will raise error + config = ( + {} + ) # make sure config is not None, otherwise **config will raise error swanlab.init( project=project_name, experiment_name=experiment_name, @@ -117,7 +146,9 @@ def __init__(self, project_name, experiment_name, default_backend: str | list[st 
self.logger["vemlp_wandb"] = vemlp_wandb if "tensorboard" in default_backend: - self.logger["tensorboard"] = _TensorboardAdapter(project_name, experiment_name) + self.logger["tensorboard"] = _TensorboardAdapter( + project_name, experiment_name + ) if "console" in default_backend: from verl.utils.logger import LocalLogger @@ -126,7 +157,9 @@ def __init__(self, project_name, experiment_name, default_backend: str | list[st self.logger["console"] = self.console_logger if "clearml" in default_backend: - self.logger["clearml"] = ClearMLLogger(project_name, experiment_name, config) + self.logger["clearml"] = ClearMLLogger( + project_name, experiment_name, config + ) def log(self, data, step, backend=None): for default_backend, logger_instance in self.logger.items(): @@ -205,7 +238,9 @@ def __init__(self, project_name, experiment_name): from torch.utils.tensorboard import SummaryWriter - tensorboard_dir = os.environ.get("TENSORBOARD_DIR", f"tensorboard_log/{project_name}/{experiment_name}") + tensorboard_dir = os.environ.get( + "TENSORBOARD_DIR", f"tensorboard_log/{project_name}/{experiment_name}" + ) os.makedirs(tensorboard_dir, exist_ok=True) print(f"Saving tensorboard log to {tensorboard_dir}.") self.writer = SummaryWriter(tensorboard_dir) @@ -230,11 +265,17 @@ def _compute_mlflow_params_from_objects(params) -> dict[str, Any]: if params is None: return {} - return _flatten_dict(_transform_params_to_json_serializable(params, convert_list_to_dict=True), sep="/") + return _flatten_dict( + _transform_params_to_json_serializable(params, convert_list_to_dict=True), + sep="/", + ) def _transform_params_to_json_serializable(x, convert_list_to_dict: bool): - _transform = partial(_transform_params_to_json_serializable, convert_list_to_dict=convert_list_to_dict) + _transform = partial( + _transform_params_to_json_serializable, + convert_list_to_dict=convert_list_to_dict, + ) if dataclasses.is_dataclass(x): return _transform(dataclasses.asdict(x)) @@ -242,7 +283,9 @@ def 
_transform_params_to_json_serializable(x, convert_list_to_dict: bool): return {k: _transform(v) for k, v in x.items()} if isinstance(x, list): if convert_list_to_dict: - return {"list_len": len(x)} | {f"{i}": _transform(v) for i, v in enumerate(x)} + return {"list_len": len(x)} | { + f"{i}": _transform(v) for i, v in enumerate(x) + } else: return [_transform(v) for v in x] if isinstance(x, Path): @@ -294,7 +337,11 @@ def _log_generations_to_wandb(self, samples, step, wandb): # Create column names for all samples columns = ["step"] + sum( - [[f"input_{i + 1}", f"output_{i + 1}", f"score_{i + 1}"] for i in range(len(samples))], [] + [ + [f"input_{i + 1}", f"output_{i + 1}", f"score_{i + 1}"] + for i in range(len(samples)) + ], + [], ) if not hasattr(self, "validation_table"): @@ -352,7 +399,9 @@ def log_generations_to_mlflow(self, samples, step): json.dump(row_data, file) mlflow.log_artifact(validation_gen_step_file) except Exception as e: - print(f"WARNING: save validation generation file to mlflow failed with error {e}") + print( + f"WARNING: save validation generation file to mlflow failed with error {e}" + ) def log_generations_to_clearml(self, samples, step): """Log validation generation to clearml as table""" diff --git a/Agent0/executor_train/verl/verl/utils/ulysses.py b/Agent0/executor_train/verl/verl/utils/ulysses.py index b37c691..a10a51b 100644 --- a/Agent0/executor_train/verl/verl/utils/ulysses.py +++ b/Agent0/executor_train/verl/verl/utils/ulysses.py @@ -83,7 +83,9 @@ def gather_seq_scatter_heads( return x -def gather_heads_scatter_seq(x: Tensor, head_dim: int, seq_dim: int, group: ProcessGroup = None) -> Tensor: +def gather_heads_scatter_seq( + x: Tensor, head_dim: int, seq_dim: int, group: ProcessGroup = None +) -> Tensor: """ A func to sync attention result with alltoall in sequence parallel gather head dimension and scatter seq dim: @@ -114,7 +116,9 @@ def _unpad_tensor(x: Tensor, dim: int, padding_size: int) -> Tensor: return x[slc] -def 
slice_input_tensor(x: Tensor, dim: int, padding: bool = True, group: ProcessGroup = None) -> Tensor: +def slice_input_tensor( + x: Tensor, dim: int, padding: bool = True, group: ProcessGroup = None +) -> Tensor: group = get_ulysses_sequence_parallel_group() if group is None else group sp_world_size = dist.get_world_size(group) sp_rank = get_ulysses_sequence_parallel_rank() @@ -139,7 +143,10 @@ def all_to_all_tensor( ): group = get_ulysses_sequence_parallel_group() if group is None else group seq_world_size = dist.get_world_size(group) - input_list = [t.contiguous() for t in torch.tensor_split(local_input, seq_world_size, scatter_dim)] + input_list = [ + t.contiguous() + for t in torch.tensor_split(local_input, seq_world_size, scatter_dim) + ] output_list = [torch.empty_like(input_list[0]) for _ in range(seq_world_size)] comm = dist.all_to_all(output_list, input_list, group=group, async_op=async_op) if async_op: @@ -152,12 +159,18 @@ def wait(): return torch.cat(output_list, dim=gather_dim).contiguous() -def all_gather_tensor(local_tensor: Tensor, group: Optional[dist.ProcessGroup] = None, async_op: bool = False): +def all_gather_tensor( + local_tensor: Tensor, + group: Optional[dist.ProcessGroup] = None, + async_op: bool = False, +): group = get_ulysses_sequence_parallel_group() if group is None else group sp_world_size = dist.get_world_size(group=group) output_shape = list(local_tensor.shape) output_shape[0] = output_shape[0] * sp_world_size - output = torch.empty(output_shape, dtype=local_tensor.dtype, device=local_tensor.device) + output = torch.empty( + output_shape, dtype=local_tensor.dtype, device=local_tensor.device + ) dist.all_gather_into_tensor(output, local_tensor, group=group, async_op=async_op) return output @@ -180,10 +193,16 @@ def forward( @staticmethod def backward(ctx: Any, *grad_output: Tensor) -> tuple[None, Tensor, None, None]: - input_t = torch.cat(grad_output[1:], dim=ctx.gather_dim).contiguous() if ctx.async_op else grad_output[0] + input_t 
= ( + torch.cat(grad_output[1:], dim=ctx.gather_dim).contiguous() + if ctx.async_op + else grad_output[0] + ) return ( None, - all_to_all_tensor(input_t, ctx.gather_dim, ctx.scatter_dim, ctx.group, False), + all_to_all_tensor( + input_t, ctx.gather_dim, ctx.scatter_dim, ctx.group, False + ), None, None, None, @@ -226,7 +245,9 @@ def backward(ctx: Any, grad_output: Tensor) -> Any: grad_output = grad_output * ctx.sp_world_size return ( None, - grad_output.split(ctx.part_size, dim=ctx.gather_dim)[ctx.sp_rank].contiguous(), + grad_output.split(ctx.part_size, dim=ctx.gather_dim)[ + ctx.sp_rank + ].contiguous(), None, None, None, @@ -262,14 +283,20 @@ def gather_outpus_and_unpad( return x x = Gather.apply(group, x, gather_dim, grad_scaler) if unpad_dim is not None: - assert isinstance(padding_size, int), "padding size is not given or is not an integer" + assert isinstance( + padding_size, int + ), "padding size is not given or is not an integer" if padding_size == 0: return x x = _unpad_tensor(x, unpad_dim, padding_size) return x -def ulysses_pad(input_ids_rmpad: torch.Tensor, position_ids_rmpad: Optional[torch.Tensor] = None, sp_size: int = 1): +def ulysses_pad( + input_ids_rmpad: torch.Tensor, + position_ids_rmpad: Optional[torch.Tensor] = None, + sp_size: int = 1, +): if position_ids_rmpad is not None: assert position_ids_rmpad.size(-2) == 1 assert input_ids_rmpad.size(-1) == position_ids_rmpad.size(-1) @@ -278,9 +305,13 @@ def ulysses_pad(input_ids_rmpad: torch.Tensor, position_ids_rmpad: Optional[torc _, total_seq_len = input_ids_rmpad.shape pad_size = (sp_size - total_seq_len % sp_size) % sp_size if pad_size > 0: - input_ids_rmpad = torch.nn.functional.pad(input_ids_rmpad, (0, pad_size), value=0) + input_ids_rmpad = torch.nn.functional.pad( + input_ids_rmpad, (0, pad_size), value=0 + ) if position_ids_rmpad is not None: - pad_pos_ids = torch.arange(pad_size, device=position_ids_rmpad.device).unsqueeze(0) + pad_pos_ids = torch.arange( + pad_size, 
device=position_ids_rmpad.device + ).unsqueeze(0) if position_ids_rmpad.dim() == 3: pad_pos_ids = pad_pos_ids.unsqueeze(0).repeat(3, 1, 1) position_ids_rmpad = torch.cat((position_ids_rmpad, pad_pos_ids), dim=-1) @@ -288,7 +319,9 @@ def ulysses_pad(input_ids_rmpad: torch.Tensor, position_ids_rmpad: Optional[torc def ulysses_pad_and_slice_inputs( - input_ids_rmpad: torch.Tensor, position_ids_rmpad: Optional[torch.Tensor] = None, sp_size: int = 1 + input_ids_rmpad: torch.Tensor, + position_ids_rmpad: Optional[torch.Tensor] = None, + sp_size: int = 1, ): """ Pad and slice input_ids to be divisible by sp_size @@ -308,15 +341,19 @@ def ulysses_pad_and_slice_inputs( torch.Tensor: padded and sliced position_ids int: pad size """ - input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad(input_ids_rmpad, position_ids_rmpad, sp_size) + input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad( + input_ids_rmpad, position_ids_rmpad, sp_size + ) input_ids_rmpad = slice_input_tensor(input_ids_rmpad, dim=1, padding=False) if position_ids_rmpad is not None: - position_ids_rmpad = slice_input_tensor(position_ids_rmpad, dim=1, padding=False) + position_ids_rmpad = slice_input_tensor( + position_ids_rmpad, dim=1, padding=False + ) return input_ids_rmpad, position_ids_rmpad, pad_size def validate_ulysses_config(num_heads, ulysses_sequence_size): if ulysses_sequence_size > 1: - assert num_heads % ulysses_sequence_size == 0, ( - f"num_heads ({num_heads}) must be divisible by ulysses sequence size({ulysses_sequence_size})" - ) + assert ( + num_heads % ulysses_sequence_size == 0 + ), f"num_heads ({num_heads}) must be divisible by ulysses sequence size({ulysses_sequence_size})" diff --git a/Agent0/executor_train/verl/verl/utils/vllm_utils.py b/Agent0/executor_train/verl/verl/utils/vllm_utils.py index 25ee665..e9cd6de 100644 --- a/Agent0/executor_train/verl/verl/utils/vllm_utils.py +++ b/Agent0/executor_train/verl/verl/utils/vllm_utils.py @@ -27,7 +27,10 @@ SUPPORTED_MOE_MODELS = [] 
try: - from vllm.model_executor.models.deepseek_v2 import DeepseekV2ForCausalLM, DeepseekV3ForCausalLM + from vllm.model_executor.models.deepseek_v2 import ( + DeepseekV2ForCausalLM, + DeepseekV3ForCausalLM, + ) SUPPORTED_MOE_MODELS.append(DeepseekV2ForCausalLM) SUPPORTED_MOE_MODELS.append(DeepseekV3ForCausalLM) @@ -92,7 +95,9 @@ def patch_vllm_moe_model_weight_loader(model): model = getattr(model, "model", None) or getattr(model, "language_model", None) if model is None: - raise ValueError("The provided model does not have a valid 'model' or 'language_model' attribute.") + raise ValueError( + "The provided model does not have a valid 'model' or 'language_model' attribute." + ) for layer in model.layers: mlp_attr = MLP_ATTR_MAPPING.get(type(model), DEFAULT_MLP_ATTR) @@ -143,7 +148,9 @@ def hijack__load_adapter(self, lora_request: TensorLoRARequest) -> LoRAModel: else: lora_path = get_adapter_absolute_path(lora_request.lora_path) - peft_helper = PEFTHelper.from_local_dir(lora_path, self.max_position_embeddings) + peft_helper = PEFTHelper.from_local_dir( + lora_path, self.max_position_embeddings + ) # Validates the LoRA configuration against requirements before # loading weights, throwing an exception if validation fails. @@ -153,7 +160,10 @@ def hijack__load_adapter(self, lora_request: TensorLoRARequest) -> LoRAModel: # to ensure correct loading of lora weights. 
model = self._adapter_manager.model hf_to_vllm_mapper = None - if hasattr(model, "hf_to_vllm_mapper") and model.hf_to_vllm_mapper is not None: + if ( + hasattr(model, "hf_to_vllm_mapper") + and model.hf_to_vllm_mapper is not None + ): hf_to_vllm_mapper = model.hf_to_vllm_mapper if isinstance(lora_request, TensorLoRARequest): @@ -164,7 +174,8 @@ def hijack__load_adapter(self, lora_request: TensorLoRARequest) -> LoRAModel: device="cpu", dtype=self.lora_config.lora_dtype, embeddings=None, - target_embedding_padding=self.vocab_size + self.lora_config.lora_extra_vocab_size, + target_embedding_padding=self.vocab_size + + self.lora_config.lora_extra_vocab_size, embedding_modules=self.embedding_modules, embedding_padding_modules=self.embedding_padding_modules, weights_mapper=hf_to_vllm_mapper, @@ -177,7 +188,8 @@ def hijack__load_adapter(self, lora_request: TensorLoRARequest) -> LoRAModel: lora_model_id=lora_request.lora_int_id, device="cpu", dtype=self.lora_config.lora_dtype, - target_embedding_padding=self.vocab_size + self.lora_config.lora_extra_vocab_size, + target_embedding_padding=self.vocab_size + + self.lora_config.lora_extra_vocab_size, embedding_modules=self.embedding_modules, embedding_padding_modules=self.embedding_padding_modules, weights_mapper=hf_to_vllm_mapper, diff --git a/Agent0/executor_train/verl/verl/workers/actor/dp_actor.py b/Agent0/executor_train/verl/verl/workers/actor/dp_actor.py index f18bf6b..7d21f2d 100644 --- a/Agent0/executor_train/verl/verl/workers/actor/dp_actor.py +++ b/Agent0/executor_train/verl/verl/workers/actor/dp_actor.py @@ -27,20 +27,44 @@ import verl.utils.torch_functional as verl_F from verl import DataProto -from verl.trainer.ppo.core_algos import agg_loss, compute_policy_loss, get_policy_loss_fn, kl_penalty -from verl.utils.device import get_device_id, get_device_name, is_cuda_available, is_npu_available +from verl.trainer.ppo.core_algos import ( + agg_loss, + compute_policy_loss, + get_policy_loss_fn, + kl_penalty, +) +from 
verl.utils.device import ( + get_device_id, + get_device_name, + is_cuda_available, + is_npu_available, +) from verl.utils.fsdp_utils import FSDPModule, fsdp2_clip_grad_norm_ from verl.utils.profiler import GPUMemoryLogger from verl.utils.py_functional import append_to_dict from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches from verl.utils.torch_functional import logprobs_from_logits -from verl.utils.ulysses import gather_outpus_and_unpad, ulysses_pad, ulysses_pad_and_slice_inputs +from verl.utils.ulysses import ( + gather_outpus_and_unpad, + ulysses_pad, + ulysses_pad_and_slice_inputs, +) from verl.workers.actor import BasePPOActor if is_cuda_available: - from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input + from flash_attn.bert_padding import ( + index_first_axis, + pad_input, + rearrange, + unpad_input, + ) elif is_npu_available: - from transformers.integrations.npu_flash_attention import index_first_axis, pad_input, rearrange, unpad_input + from transformers.integrations.npu_flash_attention import ( + index_first_axis, + pad_input, + rearrange, + unpad_input, + ) __all__ = ["DataParallelPPOActor"] @@ -50,7 +74,12 @@ class DataParallelPPOActor(BasePPOActor): - def __init__(self, config, actor_module: nn.Module, actor_optimizer: torch.optim.Optimizer = None): + def __init__( + self, + config, + actor_module: nn.Module, + actor_optimizer: torch.optim.Optimizer = None, + ): """When optimizer is None, it is Reference Policy""" super().__init__(config) self.actor_module = actor_module @@ -73,7 +102,9 @@ def __init__(self, config, actor_module: nn.Module, actor_optimizer: torch.optim self.compute_entropy_from_logits = ( torch.compile(entropy_from_logits, dynamic=True) - if self.config.get("use_torch_compile", True) # use torch compile by default + if self.config.get( + "use_torch_compile", True + ) # use torch compile by default else entropy_from_logits ) self.device_name = get_device_name() @@ -91,11 
+122,14 @@ def _forward_micro_batch( if "multi_modal_inputs" in micro_batch.keys(): if "image_bound" in micro_batch["multi_modal_inputs"][0]: # minicpm-o logic for key in micro_batch["multi_modal_inputs"][0].keys(): - multi_modal_inputs[key] = [inputs[key] for inputs in micro_batch["multi_modal_inputs"]] + multi_modal_inputs[key] = [ + inputs[key] for inputs in micro_batch["multi_modal_inputs"] + ] else: for key in micro_batch["multi_modal_inputs"][0].keys(): multi_modal_inputs[key] = torch.cat( - [inputs[key] for inputs in micro_batch["multi_modal_inputs"]], dim=0 + [inputs[key] for inputs in micro_batch["multi_modal_inputs"]], + dim=0, ) with torch.autocast(device_type=self.device_name, dtype=torch.bfloat16): @@ -105,7 +139,9 @@ def _forward_micro_batch( position_ids = micro_batch["position_ids"] entropy = None if position_ids.dim() == 3: # qwen2vl mrope - position_ids = position_ids.transpose(0, 1) # (bsz, 3, seqlen) -> (3, bsz, seqlen) + position_ids = position_ids.transpose( + 0, 1 + ) # (bsz, 3, seqlen) -> (3, bsz, seqlen) if self.use_remove_padding: input_ids_rmpad, indices, cu_seqlens, *_ = unpad_input( @@ -116,24 +152,35 @@ def _forward_micro_batch( # unpad the position_ids to align the rotary if position_ids.dim() == 3: position_ids_rmpad = ( - index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."), indices) + index_first_axis( + rearrange(position_ids, "c b s ... -> (b s) c ..."), indices + ) .transpose(0, 1) .unsqueeze(1) ) # (3, bsz, seqlen) -> (3, 1, bsz * seqlen) else: position_ids_rmpad = index_first_axis( - rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices + rearrange(position_ids.unsqueeze(-1), "b s ... 
-> (b s) ..."), + indices, ).transpose(0, 1) if "image_bound" in multi_modal_inputs: - from verl.utils.dataset.vision_utils import process_multi_modal_inputs_for_minicpmo + from verl.utils.dataset.vision_utils import ( + process_multi_modal_inputs_for_minicpmo, + ) multi_modal_inputs = process_multi_modal_inputs_for_minicpmo( - input_ids, attention_mask, position_ids, cu_seqlens, multi_modal_inputs + input_ids, + attention_mask, + position_ids, + cu_seqlens, + multi_modal_inputs, ) # for compute the log_prob - input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1) # (1, total_nnz) + input_ids_rmpad_rolled = torch.roll( + input_ids_rmpad, shifts=-1, dims=1 + ) # (1, total_nnz) # pad and slice the inputs if sp > 1 if self.use_ulysses_sp: @@ -146,10 +193,12 @@ def _forward_micro_batch( sp_size=self.ulysses_sequence_parallel_size, ) else: - input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs( - input_ids_rmpad, - position_ids_rmpad=position_ids_rmpad, - sp_size=self.ulysses_sequence_parallel_size, + input_ids_rmpad, position_ids_rmpad, pad_size = ( + ulysses_pad_and_slice_inputs( + input_ids_rmpad, + position_ids_rmpad=position_ids_rmpad, + sp_size=self.ulysses_sequence_parallel_size, + ) ) input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs( input_ids_rmpad_rolled, @@ -157,7 +206,9 @@ def _forward_micro_batch( sp_size=self.ulysses_sequence_parallel_size, ) - input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0) # ((total_nnz / sp) + pad) + input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze( + 0 + ) # ((total_nnz / sp) + pad) # only pass input_ids and position_ids to enable flash_attn_varlen extra_args = {} @@ -195,7 +246,9 @@ def _forward_micro_batch( # compute entropy if calculate_entropy: if not self.config.entropy_checkpointing: - entropy_rmpad = self.compute_entropy_from_logits(logits_rmpad) # ((total_nnz / sp) + pad) + entropy_rmpad = self.compute_entropy_from_logits( + logits_rmpad + ) # ((total_nnz / 
sp) + pad) else: entropy_rmpad = torch.utils.checkpoint.checkpoint( self.compute_entropy_from_logits, logits_rmpad @@ -234,8 +287,12 @@ def _forward_micro_batch( # only return response part: if calculate_entropy: - entropy = full_entropy.squeeze(-1)[:, -response_length - 1 : -1] # (bsz, response_length) - log_probs = full_log_probs.squeeze(-1)[:, -response_length - 1 : -1] # (bsz, response_length) + entropy = full_entropy.squeeze(-1)[ + :, -response_length - 1 : -1 + ] # (bsz, response_length) + log_probs = full_log_probs.squeeze(-1)[ + :, -response_length - 1 : -1 + ] # (bsz, response_length) else: # not using rmpad and no ulysses sp extra_args = {} @@ -254,19 +311,27 @@ def _forward_micro_batch( if self.use_fused_kernels: log_probs = output.log_probs[:, -response_length - 1 : -1] - entropy = output.entropy[:, -response_length - 1 : -1] # (bsz, response_length) + entropy = output.entropy[ + :, -response_length - 1 : -1 + ] # (bsz, response_length) else: logits = output.logits logits.div_(temperature) - logits = logits[:, -response_length - 1 : -1, :] # (bsz, response_length, vocab_size) + logits = logits[ + :, -response_length - 1 : -1, : + ] # (bsz, response_length, vocab_size) log_probs = logprobs_from_logits(logits, micro_batch["responses"]) if calculate_entropy: if not self.config.entropy_checkpointing: - entropy = verl_F.entropy_from_logits(logits) # (bsz, response_length) + entropy = verl_F.entropy_from_logits( + logits + ) # (bsz, response_length) else: - entropy = torch.utils.checkpoint.checkpoint(verl_F.entropy_from_logits, logits) + entropy = torch.utils.checkpoint.checkpoint( + verl_F.entropy_from_logits, logits + ) return entropy, log_probs @@ -274,22 +339,32 @@ def _optimizer_step(self): assert self.config.grad_clip is not None if isinstance(self.actor_module, FSDP): - grad_norm = self.actor_module.clip_grad_norm_(max_norm=self.config.grad_clip) + grad_norm = self.actor_module.clip_grad_norm_( + max_norm=self.config.grad_clip + ) elif 
isinstance(self.actor_module, FSDPModule): - grad_norm = fsdp2_clip_grad_norm_(self.actor_module.parameters(), max_norm=self.config.grad_clip) + grad_norm = fsdp2_clip_grad_norm_( + self.actor_module.parameters(), max_norm=self.config.grad_clip + ) else: - grad_norm = torch.nn.utils.clip_grad_norm_(self.actor_module.parameters(), max_norm=self.config.grad_clip) + grad_norm = torch.nn.utils.clip_grad_norm_( + self.actor_module.parameters(), max_norm=self.config.grad_clip + ) # if grad_norm is not finite, skip the update if not torch.isfinite(grad_norm): - print(f"WARN: rank {torch.distributed.get_rank()} grad_norm is not finite: {grad_norm}") + print( + f"WARN: rank {torch.distributed.get_rank()} grad_norm is not finite: {grad_norm}" + ) self.actor_optimizer.zero_grad() else: self.actor_optimizer.step() return grad_norm @GPUMemoryLogger(role="dp actor", logger=logger) - def compute_log_prob(self, data: DataProto, calculate_entropy=False) -> torch.Tensor: + def compute_log_prob( + self, data: DataProto, calculate_entropy=False + ) -> torch.Tensor: """Compute the log probability of the responses given input_ids, attention_mask and position_ids Args: @@ -311,7 +386,9 @@ def compute_log_prob(self, data: DataProto, calculate_entropy=False) -> torch.Te self.actor_module.eval() micro_batch_size = data.meta_info["micro_batch_size"] - temperature = data.meta_info["temperature"] # temperature must be in the data.meta_info to avoid silent error + temperature = data.meta_info[ + "temperature" + ] # temperature must be in the data.meta_info to avoid silent error use_dynamic_bsz = data.meta_info["use_dynamic_bsz"] def _get_micro_batches(data: DataProto) -> tuple[list, list | None]: @@ -320,17 +397,27 @@ def _get_micro_batches(data: DataProto) -> tuple[list, list | None]: has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch if has_multi_modal_inputs: - all_multi_modal_inputs_list = data.non_tensor_batch["multi_modal_inputs"] + all_multi_modal_inputs_list = 
data.non_tensor_batch[ + "multi_modal_inputs" + ] if use_dynamic_bsz: - max_token_len = data.meta_info["max_token_len"] * self.ulysses_sequence_parallel_size - rearranged_text_micro_batches, textual_indices = rearrange_micro_batches( - batch=batch, max_token_len=max_token_len + max_token_len = ( + data.meta_info["max_token_len"] + * self.ulysses_sequence_parallel_size + ) + rearranged_text_micro_batches, textual_indices = ( + rearrange_micro_batches( + batch=batch, max_token_len=max_token_len + ) ) final_micro_batches_list = [] for i, text_mb_td in enumerate(rearranged_text_micro_batches): current_original_indices = textual_indices[i] - current_mm_inputs_list = [all_multi_modal_inputs_list[idx] for idx in current_original_indices] + current_mm_inputs_list = [ + all_multi_modal_inputs_list[idx] + for idx in current_original_indices + ] mb_dict = {k: v for k, v in text_mb_td.items()} mb_dict["multi_modal_inputs"] = current_mm_inputs_list @@ -341,8 +428,13 @@ def _get_micro_batches(data: DataProto) -> tuple[list, list | None]: micro_batches_dp = data.chunk(num_micro_batches) return micro_batches_dp, None elif use_dynamic_bsz: - max_token_len = data.meta_info["max_token_len"] * self.ulysses_sequence_parallel_size - micro_batches, indices = rearrange_micro_batches(batch=batch, max_token_len=max_token_len) + max_token_len = ( + data.meta_info["max_token_len"] + * self.ulysses_sequence_parallel_size + ) + micro_batches, indices = rearrange_micro_batches( + batch=batch, max_token_len=max_token_len + ) return micro_batches, indices else: micro_batches = batch.split(micro_batch_size) @@ -357,7 +449,9 @@ def _get_micro_batches(data: DataProto) -> tuple[list, list | None]: micro_batch = {**micro_batch.batch, **micro_batch.non_tensor_batch} with torch.no_grad(): entropy, log_probs = self._forward_micro_batch( - micro_batch, temperature=temperature, calculate_entropy=calculate_entropy + micro_batch, + temperature=temperature, + calculate_entropy=calculate_entropy, ) 
log_probs_lst.append(log_probs) if calculate_entropy: @@ -369,7 +463,9 @@ def _get_micro_batches(data: DataProto) -> tuple[list, list | None]: entropys = torch.concat(entropy_lst, dim=0) if use_dynamic_bsz: indices = list(itertools.chain.from_iterable(indices)) - assert len(indices) == log_probs.size(0), f"{len(indices)} vs. {log_probs.size()}" + assert len(indices) == log_probs.size( + 0 + ), f"{len(indices)} vs. {log_probs.size()}" revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long) log_probs = log_probs[revert_indices] if calculate_entropy: @@ -382,7 +478,9 @@ def update_policy(self, data: DataProto): # make sure we are in training mode self.actor_module.train() - temperature = data.meta_info["temperature"] # temperature must be in the data.meta_info to avoid silent error + temperature = data.meta_info[ + "temperature" + ] # temperature must be in the data.meta_info to avoid silent error select_keys = [ "responses", @@ -401,9 +499,13 @@ def update_policy(self, data: DataProto): # Split to make minibatch iterator for updating the actor # See PPO paper for details. 
https://arxiv.org/abs/1707.06347 if has_multi_modal_inputs: - num_mini_batches = data.batch.batch_size[0] // self.config.ppo_mini_batch_size + num_mini_batches = ( + data.batch.batch_size[0] // self.config.ppo_mini_batch_size + ) non_tensor_select_keys = ["multi_modal_inputs"] - dataloader = data.select(select_keys, non_tensor_select_keys).chunk(num_mini_batches) + dataloader = data.select(select_keys, non_tensor_select_keys).chunk( + num_mini_batches + ) else: dataloader = batch.split(self.config.ppo_mini_batch_size) @@ -415,38 +517,63 @@ def update_policy(self, data: DataProto): if has_multi_modal_inputs: micro_batches = [] if self.config.use_dynamic_bsz: - all_multi_modal_inputs_list = data.non_tensor_batch["multi_modal_inputs"] + all_multi_modal_inputs_list = data.non_tensor_batch[ + "multi_modal_inputs" + ] batch_tensordict_for_rearrange = data.batch - max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size - rearranged_text_micro_batches_tds, textual_indices = rearrange_micro_batches( - batch=batch_tensordict_for_rearrange, max_token_len=max_token_len + max_token_len = ( + self.config.ppo_max_token_len_per_gpu + * self.ulysses_sequence_parallel_size + ) + rearranged_text_micro_batches_tds, textual_indices = ( + rearrange_micro_batches( + batch=batch_tensordict_for_rearrange, + max_token_len=max_token_len, + ) ) for current_original_indices, text_mb_td in zip( - textual_indices, rearranged_text_micro_batches_tds, strict=True + textual_indices, + rearranged_text_micro_batches_tds, + strict=True, ): current_mm_inputs_list = [ - all_multi_modal_inputs_list[idx] for idx in current_original_indices + all_multi_modal_inputs_list[idx] + for idx in current_original_indices ] mb_dict = {k: v for k, v in text_mb_td.items()} mb_dict["multi_modal_inputs"] = current_mm_inputs_list micro_batches.append(mb_dict) else: self.gradient_accumulation = ( - self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu + 
self.config.ppo_mini_batch_size + // self.config.ppo_micro_batch_size_per_gpu + ) + num_micro_batches = ( + mini_batch.batch.batch_size[0] + // self.config.ppo_micro_batch_size_per_gpu ) - num_micro_batches = mini_batch.batch.batch_size[0] // self.config.ppo_micro_batch_size_per_gpu - micro_batches = data.select(select_keys, non_tensor_select_keys).chunk(num_micro_batches) + micro_batches = data.select( + select_keys, non_tensor_select_keys + ).chunk(num_micro_batches) elif self.config.use_dynamic_bsz: - max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size - micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len) + max_token_len = ( + self.config.ppo_max_token_len_per_gpu + * self.ulysses_sequence_parallel_size + ) + micro_batches, _ = rearrange_micro_batches( + batch=mini_batch, max_token_len=max_token_len + ) else: self.gradient_accumulation = ( - self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu + self.config.ppo_mini_batch_size + // self.config.ppo_micro_batch_size_per_gpu ) # split batch into micro_batches - micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu) + micro_batches = mini_batch.split( + self.config.ppo_micro_batch_size_per_gpu + ) self.actor_optimizer.zero_grad() @@ -455,29 +582,42 @@ def update_policy(self, data: DataProto): # Support all hardwares if isinstance(data, DataProto): - data = {**data.batch.to(get_device_id()), **data.non_tensor_batch} + data = { + **data.batch.to(get_device_id()), + **data.non_tensor_batch, + } elif isinstance(data, dict): for k, v in data.items(): if isinstance(v, torch.Tensor): data[k] = v.to(get_device_id()) elif k == "multi_modal_inputs" and v is not None: data[k] = [ - {kk: vv.to(get_device_id()) for kk, vv in item_dict.items()} for item_dict in v + { + kk: vv.to(get_device_id()) + for kk, vv in item_dict.items() + } + for item_dict in v ] else: data[k] = v else: - data = 
data.to(get_device_id()) # actor device is cpu when using offload + data = data.to( + get_device_id() + ) # actor device is cpu when using offload response_mask = data["response_mask"] old_log_prob = data["old_log_probs"] advantages = data["advantages"] clip_ratio = self.config.clip_ratio clip_ratio_low = ( - self.config.clip_ratio_low if self.config.clip_ratio_low is not None else clip_ratio + self.config.clip_ratio_low + if self.config.clip_ratio_low is not None + else clip_ratio ) clip_ratio_high = ( - self.config.clip_ratio_high if self.config.clip_ratio_high is not None else clip_ratio + self.config.clip_ratio_high + if self.config.clip_ratio_high is not None + else clip_ratio ) clip_ratio_c = self.config.get("clip_ratio_c", 3.0) entropy_coeff = self.config.entropy_coeff @@ -488,37 +628,47 @@ def update_policy(self, data: DataProto): if entropy_coeff != 0: calculate_entropy = True entropy, log_prob = self._forward_micro_batch( - micro_batch=data, temperature=temperature, calculate_entropy=calculate_entropy + micro_batch=data, + temperature=temperature, + calculate_entropy=calculate_entropy, ) loss_mode = self.config.policy_loss.get("loss_mode", "vanilla") if self.config.policy_loss.loss_mode == "vanilla": - pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = compute_policy_loss( - old_log_prob=old_log_prob, - log_prob=log_prob, - advantages=advantages, - response_mask=response_mask, - cliprange=clip_ratio, - cliprange_low=clip_ratio_low, - cliprange_high=clip_ratio_high, - clip_ratio_c=clip_ratio_c, - loss_agg_mode=loss_agg_mode, + pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = ( + compute_policy_loss( + old_log_prob=old_log_prob, + log_prob=log_prob, + advantages=advantages, + response_mask=response_mask, + cliprange=clip_ratio, + cliprange_low=clip_ratio_low, + cliprange_high=clip_ratio_high, + clip_ratio_c=clip_ratio_c, + loss_agg_mode=loss_agg_mode, + ) ) else: policy_loss_fn = get_policy_loss_fn(loss_mode) - pg_loss, pg_clipfrac, ppo_kl, 
pg_clipfrac_lower = policy_loss_fn( - old_log_prob=old_log_prob, - log_prob=log_prob, - advantages=advantages, - response_mask=advantages, - loss_agg_mode=loss_agg_mode, - config=self.config, + pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = ( + policy_loss_fn( + old_log_prob=old_log_prob, + log_prob=log_prob, + advantages=advantages, + response_mask=advantages, + loss_agg_mode=loss_agg_mode, + config=self.config, + ) ) if entropy_coeff != 0: - entropy_loss = agg_loss(loss_mat=entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) + entropy_loss = agg_loss( + loss_mat=entropy, + loss_mask=response_mask, + loss_agg_mode=loss_agg_mode, + ) # compute policy loss policy_loss = pg_loss - entropy_loss * entropy_coeff @@ -529,9 +679,15 @@ def update_policy(self, data: DataProto): ref_log_prob = data["ref_log_prob"] # compute kl loss kld = kl_penalty( - logprob=log_prob, ref_logprob=ref_log_prob, kl_penalty=self.config.kl_loss_type + logprob=log_prob, + ref_logprob=ref_log_prob, + kl_penalty=self.config.kl_loss_type, + ) + kl_loss = agg_loss( + loss_mat=kld, + loss_mask=response_mask, + loss_agg_mode=loss_agg_mode, ) - kl_loss = agg_loss(loss_mat=kld, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef micro_batch_metrics["actor/kl_loss"] = kl_loss.detach().item() @@ -539,7 +695,9 @@ def update_policy(self, data: DataProto): if self.config.use_dynamic_bsz: # relative to the dynamic bsz - loss = policy_loss * (len(data) / self.config.ppo_mini_batch_size) + loss = policy_loss * ( + len(data) / self.config.ppo_mini_batch_size + ) else: loss = policy_loss / self.gradient_accumulation loss.backward() diff --git a/Agent0/executor_train/verl/verl/workers/actor/megatron_actor.py b/Agent0/executor_train/verl/verl/workers/actor/megatron_actor.py index 08238d4..6417f2e 100644 --- a/Agent0/executor_train/verl/verl/workers/actor/megatron_actor.py +++ 
b/Agent0/executor_train/verl/verl/workers/actor/megatron_actor.py @@ -37,10 +37,18 @@ from torch import nn from verl import DataProto -from verl.trainer.ppo.core_algos import agg_loss, compute_policy_loss, get_policy_loss_fn, kl_penalty +from verl.trainer.ppo.core_algos import ( + agg_loss, + compute_policy_loss, + get_policy_loss_fn, + kl_penalty, +) from verl.utils.device import get_device_id, get_torch_device from verl.utils.megatron.pipeline_parallel import make_batch_generator -from verl.utils.megatron.tensor_parallel import vocab_parallel_entropy, vocab_parallel_log_probs_from_logits +from verl.utils.megatron.tensor_parallel import ( + vocab_parallel_entropy, + vocab_parallel_log_probs_from_logits, +) from verl.utils.megatron_utils import get_model_config from verl.utils.profiler import GPUMemoryLogger from verl.utils.profiler.profile import Profiler @@ -152,14 +160,18 @@ def _validate_config(self, config) -> None: """Validate config options not implemented for Megatron backend""" assert config.get("ulysses_sequence_parallel_size", 1) == 1 if config.get("shuffle", False): - assert config.data_loader_seed is not None, "If shuffle dataloader, seed must be manually set" + assert ( + config.data_loader_seed is not None + ), "If shuffle dataloader, seed must be manually set" if config.megatron.tensor_model_parallel_size == 1: print("[Warining] Because actor tp size == 1, set sp to False") config.megatron.sequence_parallel = False self.config = config @GPUMemoryLogger(role="megatron actor", logger=logger) - def compute_log_prob(self, data: DataProto, calculate_entropy=False) -> torch.Tensor: + def compute_log_prob( + self, data: DataProto, calculate_entropy=False + ) -> torch.Tensor: """Compute the log probability of the responses given input_ids, attention_mask and position_ids Args: @@ -182,9 +194,13 @@ def compute_log_prob(self, data: DataProto, calculate_entropy=False) -> torch.Te use_dynamic_bsz = data.meta_info.get("use_dynamic_bsz", False) micro_batch_size = 
data.meta_info.get("micro_batch_size", None) max_token_len = data.meta_info.get("max_token_len", None) - assert micro_batch_size is not None, "micro batch size is needed for forward compute" + assert ( + micro_batch_size is not None + ), "micro batch size is needed for forward compute" if use_dynamic_bsz: - assert max_token_len is not None, "max_token_len must be set when use_dynamic_bsz is True" + assert ( + max_token_len is not None + ), "max_token_len must be set when use_dynamic_bsz is True" max_token_len = max_token_len * self.config.megatron.context_parallel_size def compute_logprobs_fn(output, data, use_dynamic_bsz=False, indices=None): @@ -219,19 +235,29 @@ def compute_logprobs_fn(output, data, use_dynamic_bsz=False, indices=None): if mpu.is_pipeline_last_stage(ignore_virtual=True): # only on last rank. It should be on every tp rank if calculate_entropy: - log_probs = [o[0]["log_probs"] for o in output["output"]] # (bs, seq_size) + log_probs = [ + o[0]["log_probs"] for o in output["output"] + ] # (bs, seq_size) else: - log_probs = [o["log_probs"] for o in output["output"]] # (bs, seq_size) + log_probs = [ + o["log_probs"] for o in output["output"] + ] # (bs, seq_size) log_probs = torch.cat(log_probs, dim=0).to(torch.float32) if use_dynamic_bsz: indices = output["indices"] indices = list(itertools.chain.from_iterable(indices)) - assert len(indices) == log_probs.size(0), f"{len(indices)} vs. {log_probs.size()}" - revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long) + assert len(indices) == log_probs.size( + 0 + ), f"{len(indices)} vs. 
{log_probs.size()}" + revert_indices = torch.tensor( + get_reverse_idx(indices), dtype=torch.long + ) log_probs = log_probs[revert_indices] else: log_probs = torch.empty( - size=(batch_size, response_length), dtype=torch.float32, device=input_ids.device + size=(batch_size, response_length), + dtype=torch.float32, + device=input_ids.device, ) # broadcast across pp ranks @@ -249,12 +275,18 @@ def compute_logprobs_fn(output, data, use_dynamic_bsz=False, indices=None): if use_dynamic_bsz: indices = output["indices"] indices = list(itertools.chain.from_iterable(indices)) - assert len(indices) == entropys.size(0), f"{len(indices)} vs. {entropys.size()}" - revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long) + assert len(indices) == entropys.size( + 0 + ), f"{len(indices)} vs. {entropys.size()}" + revert_indices = torch.tensor( + get_reverse_idx(indices), dtype=torch.long + ) entropys = entropys[revert_indices] else: entropys = torch.empty( - size=(batch_size, response_length), dtype=torch.float32, device=input_ids.device + size=(batch_size, response_length), + dtype=torch.float32, + device=input_ids.device, ) # broadcast across pp ranks torch.distributed.broadcast( @@ -295,10 +327,19 @@ def make_minibatch_iterator(self, data: DataProto) -> Iterable[DataProto]: Returns: """ - select_keys = ["responses", "input_ids", "attention_mask", "position_ids", "old_log_probs", "advantages"] + select_keys = [ + "responses", + "input_ids", + "attention_mask", + "position_ids", + "old_log_probs", + "advantages", + ] if self.config.use_kl_loss: select_keys.append("ref_log_prob") - self.has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys() + self.has_multi_modal_inputs = ( + "multi_modal_inputs" in data.non_tensor_batch.keys() + ) if self.has_multi_modal_inputs: data = data.select(select_keys, ["multi_modal_inputs"]) else: @@ -336,40 +377,56 @@ def forward_backward_batch( ) # split into micro-batches mini_batch.batch["attention_mask"] = 
mini_batch.batch["attention_mask"].to(bool) - self.has_multi_modal_inputs = "multi_modal_inputs" in mini_batch.non_tensor_batch.keys() + self.has_multi_modal_inputs = ( + "multi_modal_inputs" in mini_batch.non_tensor_batch.keys() + ) if self.has_multi_modal_inputs: - mini_batch.batch["multi_modal_inputs"] = mini_batch.non_tensor_batch["multi_modal_inputs"] + mini_batch.batch["multi_modal_inputs"] = mini_batch.non_tensor_batch[ + "multi_modal_inputs" + ] mini_batch.batch["multi_modal_inputs_idx"] = torch.Tensor( list(range(len(mini_batch.non_tensor_batch["multi_modal_inputs"]))) ).to(torch.int64) - if mini_batch.batch["position_ids"].dim() == 3: # qwen2vl mrope [bs, 3, seq_len] + if ( + mini_batch.batch["position_ids"].dim() == 3 + ): # qwen2vl mrope [bs, 3, seq_len] mini_batch.batch["position_ids"] = mini_batch.batch["position_ids"][ :, 0 ] # mcore patch recompute qwen2vl's pos ids during forward indices = None if use_dynamic_bsz: - assert max_token_len is not None, "max_token_len must be set when use_dynamic_bsz is True" + assert ( + max_token_len is not None + ), "max_token_len must be set when use_dynamic_bsz is True" vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size() if vpp_size is not None and vpp_size > 1: - microbatch_group_size_per_vp_stage = self.tf_config.microbatch_group_size_per_vp_stage + microbatch_group_size_per_vp_stage = ( + self.tf_config.microbatch_group_size_per_vp_stage + ) micro_batches, indices = rearrange_micro_batches( batch=mini_batch.batch, num_batches_divided_by=microbatch_group_size_per_vp_stage, max_token_len=max_token_len, ) - assert len(micro_batches) % self.tf_config.microbatch_group_size_per_vp_stage == 0, ( + assert ( + len(micro_batches) + % self.tf_config.microbatch_group_size_per_vp_stage + == 0 + ), ( f"micro_batches {micro_batches} must be divisible by microbatch_group_size_per_vp_stage " f"{microbatch_group_size_per_vp_stage} for megatron backend" ) else: - micro_batches, indices = 
rearrange_micro_batches(batch=mini_batch.batch, max_token_len=max_token_len) + micro_batches, indices = rearrange_micro_batches( + batch=mini_batch.batch, max_token_len=max_token_len + ) total_seqlen = max_token_len else: - assert micro_batch_size is not None, ( - "micro_batch_size is needed to be passed in when not using dynamic batch size" - ) + assert ( + micro_batch_size is not None + ), "micro_batch_size is needed to be passed in when not using dynamic batch size" micro_batches = mini_batch.batch.split(micro_batch_size) seq_len = micro_batches[0]["input_ids"].shape[1] total_seqlen = micro_batch_size * seq_len @@ -408,8 +465,16 @@ def loss_func(output, data, meta_info): advantages = data["advantages"] clip_ratio = self.config.clip_ratio - clip_ratio_low = self.config.clip_ratio_low if self.config.clip_ratio_low is not None else clip_ratio - clip_ratio_high = self.config.clip_ratio_high if self.config.clip_ratio_high is not None else clip_ratio + clip_ratio_low = ( + self.config.clip_ratio_low + if self.config.clip_ratio_low is not None + else clip_ratio + ) + clip_ratio_high = ( + self.config.clip_ratio_high + if self.config.clip_ratio_high is not None + else clip_ratio + ) clip_ratio_c = self.config.get("clip_ratio_c", 3.0) entropy_coeff = self.config.entropy_coeff @@ -418,16 +483,18 @@ def loss_func(output, data, meta_info): loss_mode = self.config.policy_loss.get("loss_mode", "vanilla") if self.config.policy_loss.loss_mode == "vanilla": - pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = compute_policy_loss( - old_log_prob=old_log_prob, - log_prob=log_prob, - advantages=advantages, - response_mask=response_mask, - cliprange=clip_ratio, - cliprange_low=clip_ratio_low, - cliprange_high=clip_ratio_high, - clip_ratio_c=clip_ratio_c, - loss_agg_mode=loss_agg_mode, + pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = ( + compute_policy_loss( + old_log_prob=old_log_prob, + log_prob=log_prob, + advantages=advantages, + response_mask=response_mask, + 
cliprange=clip_ratio, + cliprange_low=clip_ratio_low, + cliprange_high=clip_ratio_high, + clip_ratio_c=clip_ratio_c, + loss_agg_mode=loss_agg_mode, + ) ) else: @@ -454,7 +521,11 @@ def loss_func(output, data, meta_info): if calculate_entropy: entropy = output["entropy"][:, -response_length - 1 : -1].contiguous() if not forward_only: - entropy_loss = agg_loss(loss_mat=entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) + entropy_loss = agg_loss( + loss_mat=entropy, + loss_mask=response_mask, + loss_agg_mode=loss_agg_mode, + ) entropy_coeff = meta_info["entropy_coeff"] policy_loss = pg_loss - entropy_coeff * entropy_loss else: @@ -466,8 +537,16 @@ def loss_func(output, data, meta_info): if self.config.use_kl_loss: ref_log_prob = data["ref_log_prob"] # compute kl loss - kld = kl_penalty(logprob=log_prob, ref_logprob=ref_log_prob, kl_penalty=self.config.kl_loss_type) - kl_loss = agg_loss(loss_mat=kld, loss_mask=response_mask, loss_agg_mode=self.config.loss_agg_mode) + kld = kl_penalty( + logprob=log_prob, + ref_logprob=ref_log_prob, + kl_penalty=self.config.kl_loss_type, + ) + kl_loss = agg_loss( + loss_mat=kld, + loss_mask=response_mask, + loss_agg_mode=self.config.loss_agg_mode, + ) policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef metrics["actor/kl_loss"] = kl_loss.detach().item() @@ -490,7 +569,12 @@ def forward_step(batch_iter, model): idxs = batch["multi_modal_inputs_idx"] mmi = batch["multi_modal_inputs"] multi_modal_inputs[key] = torch.cat( - [mmi[idx].get(key) for idx in idxs if mmi[idx].get(key) is not None], dim=0 + [ + mmi[idx].get(key) + for idx in idxs + if mmi[idx].get(key) is not None + ], + dim=0, ) responses = batch["responses"] response_length = responses.size(1) @@ -500,7 +584,10 @@ def forward_step(batch_iter, model): label_mask[:, : -response_length - 1] = False label_mask[:, -1] = False - from verl.models.mcore import get_mcore_forward_fn, get_mcore_forward_fused_fn + from verl.models.mcore import ( + get_mcore_forward_fn, 
+ get_mcore_forward_fused_fn, + ) if self.use_fused_kernels: forward_fn = get_mcore_forward_fused_fn(self.hf_config) @@ -554,7 +641,9 @@ def logits_processor(logits, label, label_mask): return output, partial(loss_func, data=batch, meta_info=meta_info) # batch should be a list of batches inside micro-batches - batch_generator = make_batch_generator(micro_batches, vpp_size=len(self.actor_module)) + batch_generator = make_batch_generator( + micro_batches, vpp_size=len(self.actor_module) + ) # TODO: we may use the new schedule instead # for flash-attn: (seq_len, batch_size, hidden_size) = (mbs*seq_len, 1, hidden_size) @@ -620,7 +709,10 @@ def update_policy(self, dataloader: Iterable[DataProto]) -> dict: micro_batch_size = self.config.ppo_micro_batch_size_per_gpu max_token_len = None if self.config.use_dynamic_bsz: - max_token_len = self.config.ppo_max_token_len_per_gpu * self.config.megatron.context_parallel_size + max_token_len = ( + self.config.ppo_max_token_len_per_gpu + * self.config.megatron.context_parallel_size + ) metric_micro_batch = self.forward_backward_batch( data, calculate_entropy=calculate_entropy, @@ -632,9 +724,13 @@ def update_policy(self, dataloader: Iterable[DataProto]) -> dict: metric_micro_batch = metric_micro_batch["output"] for metric in metric_micro_batch: # Note that o[0] is metrics, o[1] is entropy, o[2] is response_mask - append_to_dict(metrics, metric[0]) # append the metric from this micro-batch to global metrics. + append_to_dict( + metrics, metric[0] + ) # append the metric from this micro-batch to global metrics. 
- update_successful, grad_norm, num_zeros_in_grad = self.actor_optimizer.step() + update_successful, grad_norm, num_zeros_in_grad = ( + self.actor_optimizer.step() + ) data = {"actor/grad_norm": grad_norm} append_to_dict(metrics, data) diff --git a/Agent0/executor_train/verl/verl/workers/critic/dp_critic.py b/Agent0/executor_train/verl/verl/workers/critic/dp_critic.py index ac77758..996b453 100644 --- a/Agent0/executor_train/verl/verl/workers/critic/dp_critic.py +++ b/Agent0/executor_train/verl/verl/workers/critic/dp_critic.py @@ -26,7 +26,12 @@ from verl import DataProto from verl.trainer.ppo import core_algos -from verl.utils.device import get_device_id, get_device_name, is_cuda_available, is_npu_available +from verl.utils.device import ( + get_device_id, + get_device_name, + is_cuda_available, + is_npu_available, +) from verl.utils.fsdp_utils import FSDPModule, fsdp2_clip_grad_norm_ from verl.utils.profiler import GPUMemoryLogger from verl.utils.py_functional import append_to_dict @@ -36,23 +41,37 @@ from verl.workers.critic import BasePPOCritic if is_cuda_available: - from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input + from flash_attn.bert_padding import ( + index_first_axis, + pad_input, + rearrange, + unpad_input, + ) elif is_npu_available: - from transformers.integrations.npu_flash_attention import index_first_axis, pad_input, rearrange, unpad_input + from transformers.integrations.npu_flash_attention import ( + index_first_axis, + pad_input, + rearrange, + unpad_input, + ) logger = logging.getLogger(__file__) logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) class DataParallelPPOCritic(BasePPOCritic): - def __init__(self, config, critic_module: nn.Module, critic_optimizer: optim.Optimizer): + def __init__( + self, config, critic_module: nn.Module, critic_optimizer: optim.Optimizer + ): super().__init__(config=config) self.critic_module = critic_module self.critic_optimizer = critic_optimizer 
self.use_remove_padding = self.config.model.get("use_remove_padding", False) print(f"Critic use_remove_padding={self.use_remove_padding}") - self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1) + self.ulysses_sequence_parallel_size = self.config.get( + "ulysses_sequence_parallel_size", 1 + ) self.device_name = get_device_name() def _forward_micro_batch(self, micro_batch): @@ -81,19 +100,26 @@ def _forward_micro_batch(self, micro_batch): # unpad the position_ids to align the rotary if position_ids.dim() == 3: position_ids_rmpad = ( - index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."), indices) + index_first_axis( + rearrange(position_ids, "c b s ... -> (b s) c ..."), indices + ) .transpose(0, 1) .unsqueeze(1) ) # (3, bsz, seqlen) -> (3, 1, bsz * seqlen) else: position_ids_rmpad = index_first_axis( - rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices + rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), + indices, ).transpose(0, 1) # pad and slice the inputs if sp > 1 if self.ulysses_sequence_parallel_size > 1: - input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs( - input_ids_rmpad, position_ids_rmpad, sp_size=self.ulysses_sequence_parallel_size + input_ids_rmpad, position_ids_rmpad, pad_size = ( + ulysses_pad_and_slice_inputs( + input_ids_rmpad, + position_ids_rmpad, + sp_size=self.ulysses_sequence_parallel_size, + ) ) # only pass input_ids and position_ids to enable flash_attn_varlen @@ -119,7 +145,9 @@ def _forward_micro_batch(self, micro_batch): ) # pad it back - values = pad_input(values_rmpad, indices=indices, batch=batch, seqlen=seqlen).squeeze(-1) + values = pad_input( + values_rmpad, indices=indices, batch=batch, seqlen=seqlen + ).squeeze(-1) values = values[:, -response_length - 1 : -1] else: output = self.critic_module( @@ -143,9 +171,13 @@ def _optimizer_step(self): if isinstance(self.critic_module, FSDP): grad_norm = 
self.critic_module.clip_grad_norm_(self.config.grad_clip) elif isinstance(self.critic_module, FSDPModule): - grad_norm = fsdp2_clip_grad_norm_(self.critic_module.parameters(), max_norm=self.config.grad_clip) + grad_norm = fsdp2_clip_grad_norm_( + self.critic_module.parameters(), max_norm=self.config.grad_clip + ) else: - grad_norm = torch.nn.utils.clip_grad_norm_(self.critic_module.parameters(), max_norm=self.config.grad_clip) + grad_norm = torch.nn.utils.clip_grad_norm_( + self.critic_module.parameters(), max_norm=self.config.grad_clip + ) # if grad_norm is not finite, skip the update if not torch.isfinite(grad_norm): @@ -167,11 +199,17 @@ def compute_values(self, data: DataProto) -> torch.Tensor: if has_multi_modal_inputs: num_micro_batches = data.batch.batch_size[0] // micro_batch_size non_tensor_select_keys = ["multi_modal_inputs"] - micro_batches = data.select(select_keys, non_tensor_select_keys).chunk(num_micro_batches) + micro_batches = data.select(select_keys, non_tensor_select_keys).chunk( + num_micro_batches + ) elif use_dynamic_bsz: # split using dynamic bsz - max_token_len = data.meta_info["max_token_len"] * self.ulysses_sequence_parallel_size - micro_batches, indices = rearrange_micro_batches(batch=batch, max_token_len=max_token_len) + max_token_len = ( + data.meta_info["max_token_len"] * self.ulysses_sequence_parallel_size + ) + micro_batches, indices = rearrange_micro_batches( + batch=batch, max_token_len=max_token_len + ) else: micro_batches = batch.split(micro_batch_size) @@ -201,16 +239,28 @@ def update_critic(self, data: DataProto): self.critic_module.train() metrics = {} - select_keys = ["input_ids", "responses", "response_mask", "attention_mask", "position_ids", "values", "returns"] + select_keys = [ + "input_ids", + "responses", + "response_mask", + "attention_mask", + "position_ids", + "values", + "returns", + ] batch = data.select(batch_keys=select_keys).batch has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys() # 
Split to make minibatch iterator for updating the actor # See PPO paper for details. https://arxiv.org/abs/1707.06347 if has_multi_modal_inputs: - num_mini_batches = data.batch.batch_size[0] // self.config.ppo_mini_batch_size + num_mini_batches = ( + data.batch.batch_size[0] // self.config.ppo_mini_batch_size + ) non_tensor_select_keys = ["multi_modal_inputs"] - dataloader = data.select(select_keys, non_tensor_select_keys).chunk(num_mini_batches) + dataloader = data.select(select_keys, non_tensor_select_keys).chunk( + num_mini_batches + ) else: dataloader = batch.split(self.config.ppo_mini_batch_size) @@ -219,18 +269,32 @@ def update_critic(self, data: DataProto): # split batch into micro_batches mini_batch = data if has_multi_modal_inputs: - num_micro_batches = mini_batch.batch.batch_size[0] // self.config.ppo_micro_batch_size_per_gpu - micro_batches = data.select(select_keys, non_tensor_select_keys).chunk(num_micro_batches) + num_micro_batches = ( + mini_batch.batch.batch_size[0] + // self.config.ppo_micro_batch_size_per_gpu + ) + micro_batches = data.select( + select_keys, non_tensor_select_keys + ).chunk(num_micro_batches) self.gradient_accumulation = ( - self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu + self.config.ppo_mini_batch_size + // self.config.ppo_micro_batch_size_per_gpu ) elif self.config.use_dynamic_bsz: - max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size - micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len) + max_token_len = ( + self.config.ppo_max_token_len_per_gpu + * self.ulysses_sequence_parallel_size + ) + micro_batches, _ = rearrange_micro_batches( + batch=mini_batch, max_token_len=max_token_len + ) else: - micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu) + micro_batches = mini_batch.split( + self.config.ppo_micro_batch_size_per_gpu + ) self.gradient_accumulation = ( - self.config.ppo_mini_batch_size // 
self.config.ppo_micro_batch_size_per_gpu + self.config.ppo_mini_batch_size + // self.config.ppo_micro_batch_size_per_gpu ) self.critic_optimizer.zero_grad() @@ -240,9 +304,14 @@ def update_critic(self, data: DataProto): # Support all devices if isinstance(data, DataProto): - data = {**data.batch.to(get_device_id()), **data.non_tensor_batch} + data = { + **data.batch.to(get_device_id()), + **data.non_tensor_batch, + } else: - data = data.to(get_device_id()) # critic device is cpu when using offload + data = data.to( + get_device_id() + ) # critic device is cpu when using offload response_mask = data["response_mask"] values = data["values"] returns = data["returns"] @@ -271,7 +340,9 @@ def update_critic(self, data: DataProto): { "critic/vf_loss": vf_loss.detach().item(), "critic/vf_clipfrac": vf_clipfrac.detach().item(), - "critic/vpred_mean": masked_mean(vpreds, response_mask).detach().item(), + "critic/vpred_mean": masked_mean(vpreds, response_mask) + .detach() + .item(), } ) diff --git a/Agent0/executor_train/verl/verl/workers/critic/megatron_critic.py b/Agent0/executor_train/verl/verl/workers/critic/megatron_critic.py index 1d44a88..b1331d8 100644 --- a/Agent0/executor_train/verl/verl/workers/critic/megatron_critic.py +++ b/Agent0/executor_train/verl/verl/workers/critic/megatron_critic.py @@ -83,7 +83,9 @@ def _validate_config(self, config) -> None: """Validate config options not implemented for Megatron backend""" assert config.get("ulysses_sequence_parallel_size", 1) == 1 if config.shuffle: - assert config.data_loader_seed is not None, "If shuffle dataloader, seed must be manually set" + assert ( + config.data_loader_seed is not None + ), "If shuffle dataloader, seed must be manually set" if config.megatron.tensor_model_parallel_size == 1: print("[Warining] Because critic tp size == 1, set sp to False") config.megatron.sequence_parallel = False @@ -97,9 +99,13 @@ def compute_values(self, data: DataProto) -> DataProto: use_dynamic_bsz = 
data.meta_info.get("use_dynamic_bsz", False) micro_batch_size = data.meta_info.get("micro_batch_size", None) max_token_len = data.meta_info.get("max_token_len", None) - assert micro_batch_size is not None, "micro batch size is needed for forward compute" + assert ( + micro_batch_size is not None + ), "micro batch size is needed for forward compute" if use_dynamic_bsz: - assert max_token_len is not None, "max_token_len must be set when use_dynamic_bsz is True" + assert ( + max_token_len is not None + ), "max_token_len must be set when use_dynamic_bsz is True" max_token_len = max_token_len * self.config.megatron.context_parallel_size response_length = responses.size(1) with torch.no_grad(): @@ -113,13 +119,19 @@ def compute_values(self, data: DataProto) -> DataProto: ) if mpu.is_pipeline_last_stage(ignore_virtual=True): # only on last rank. It should be on every tp rank - values = [o["vpreds"] for o in output["output"]] # (bs, seq_size, vocal_size) + values = [ + o["vpreds"] for o in output["output"] + ] # (bs, seq_size, vocal_size) values = torch.cat(values, dim=0).to(torch.float32) if use_dynamic_bsz: indices = output["indices"] indices = list(itertools.chain.from_iterable(indices)) - assert len(indices) == values.size(0), f"{len(indices)} vs. {values.size()}" - revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long) + assert len(indices) == values.size( + 0 + ), f"{len(indices)} vs. 
{values.size()}" + revert_indices = torch.tensor( + get_reverse_idx(indices), dtype=torch.long + ) values = values[revert_indices] else: values = torch.empty_like(attention_mask, dtype=torch.float32) @@ -145,7 +157,14 @@ def compute_values(self, data: DataProto) -> DataProto: return values def make_minibatch_iterator(self, data: DataProto) -> Iterable[DataProto]: - select_keys = ["input_ids", "responses", "attention_mask", "position_ids", "values", "returns"] + select_keys = [ + "input_ids", + "responses", + "attention_mask", + "position_ids", + "values", + "returns", + ] data = data.select(batch_keys=select_keys) return data.make_iterator( mini_batch_size=self.config.ppo_mini_batch_size, @@ -177,26 +196,36 @@ def forward_backward_batch( indices = None if use_dynamic_bsz: - assert max_token_len is not None, "max_token_len must be set when use_dynamic_bsz is True" + assert ( + max_token_len is not None + ), "max_token_len must be set when use_dynamic_bsz is True" vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size() if vpp_size is not None and vpp_size > 1: - microbatch_group_size_per_vp_stage = self.tf_config.microbatch_group_size_per_vp_stage + microbatch_group_size_per_vp_stage = ( + self.tf_config.microbatch_group_size_per_vp_stage + ) micro_batches, indices = rearrange_micro_batches( batch=mini_batch.batch, num_batches_divided_by=microbatch_group_size_per_vp_stage, max_token_len=max_token_len, ) - assert len(micro_batches) % self.tf_config.microbatch_group_size_per_vp_stage == 0, ( + assert ( + len(micro_batches) + % self.tf_config.microbatch_group_size_per_vp_stage + == 0 + ), ( f"micro_batches {micro_batches} must be divisible by microbatch_group_size_per_vp_stage " f"{microbatch_group_size_per_vp_stage} for megatron backend" ) else: - micro_batches, indices = rearrange_micro_batches(batch=mini_batch.batch, max_token_len=max_token_len) + micro_batches, indices = rearrange_micro_batches( + batch=mini_batch.batch, max_token_len=max_token_len + ) 
total_seqlen = max_token_len else: - assert micro_batch_size is not None, ( - "micro_batch_size is needed to be passed in when not using dynamic batch size" - ) + assert ( + micro_batch_size is not None + ), "micro_batch_size is needed to be passed in when not using dynamic batch size" micro_batches = mini_batch.batch.split(micro_batch_size) seq_len = micro_batches[0]["input_ids"].shape[1] total_seqlen = micro_batch_size * seq_len @@ -261,7 +290,9 @@ def forward_step(batch_iter, model): return output, partial(loss_func, data=batch, meta_info={}) # batch should be a list of batches inside micro-batches - batch_generator = make_batch_generator(micro_batches, vpp_size=len(self.critic_module)) + batch_generator = make_batch_generator( + micro_batches, vpp_size=len(self.critic_module) + ) # TODO: we may use the new schedule instead # for flash-attn: (seq_len, batch_size, hidden_size) = (mbs*seq_len, 1, hidden_size) @@ -305,7 +336,10 @@ def update_critic(self, dataloader: Iterable[DataProto]): micro_batch_size = self.config.ppo_micro_batch_size_per_gpu max_token_len = None if self.config.use_dynamic_bsz: - max_token_len = self.config.ppo_max_token_len_per_gpu * self.config.megatron.context_parallel_size + max_token_len = ( + self.config.ppo_max_token_len_per_gpu + * self.config.megatron.context_parallel_size + ) metric_micro_batch = self.forward_backward_batch( data, forward_only=False, @@ -315,7 +349,9 @@ def update_critic(self, dataloader: Iterable[DataProto]): mini_batch_size=self.config.ppo_mini_batch_size, ) metric_micro_batch = metric_micro_batch["output"] - update_successful, grad_norm, num_zeros_in_grad = self.critic_optimizer.step() + update_successful, grad_norm, num_zeros_in_grad = ( + self.critic_optimizer.step() + ) learning_rate = self.critic_optimizer.param_groups[-1]["lr"] data = {"critic/grad_norm": grad_norm, "critic/lr": learning_rate} append_to_dict(metrics, data) @@ -327,7 +363,9 @@ def update_critic(self, dataloader: Iterable[DataProto]): raise 
NotImplementedError for metric in metric_micro_batch: - append_to_dict(metrics, metric) # append the metric from this micro-batch to global metrics. + append_to_dict( + metrics, metric + ) # append the metric from this micro-batch to global metrics. # add empty cache after each compute get_torch_device().empty_cache() diff --git a/Agent0/executor_train/verl/verl/workers/fsdp_workers.py b/Agent0/executor_train/verl/verl/workers/fsdp_workers.py index f9bb475..e74d450 100644 --- a/Agent0/executor_train/verl/verl/workers/fsdp_workers.py +++ b/Agent0/executor_train/verl/verl/workers/fsdp_workers.py @@ -69,7 +69,12 @@ ) from verl.utils.import_utils import import_external_libs from verl.utils.model import compute_position_id_with_mask -from verl.utils.profiler import DistProfiler, DistProfilerExtension, log_gpu_memory_usage, simple_timer +from verl.utils.profiler import ( + DistProfiler, + DistProfilerExtension, + log_gpu_memory_usage, + simple_timer, +) from verl.utils.profiler.performance import reduce_timing from verl.utils.py_functional import convert_to_regular_types from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager @@ -82,10 +87,14 @@ def create_device_mesh(world_size, fsdp_size): if fsdp_size < 0 or fsdp_size >= world_size: - device_mesh = init_device_mesh(device_name, mesh_shape=(world_size,), mesh_dim_names=["fsdp"]) + device_mesh = init_device_mesh( + device_name, mesh_shape=(world_size,), mesh_dim_names=["fsdp"] + ) else: device_mesh = init_device_mesh( - device_name, mesh_shape=(world_size // fsdp_size, fsdp_size), mesh_dim_names=["ddp", "fsdp"] + device_name, + mesh_shape=(world_size // fsdp_size, fsdp_size), + mesh_dim_names=["ddp", "fsdp"], ) return device_mesh @@ -98,7 +107,9 @@ def get_sharding_strategy(device_mesh): elif device_mesh.ndim == 2: sharding_strategy = ShardingStrategy.HYBRID_SHARD else: - raise NotImplementedError(f"Get device mesh ndim={device_mesh.ndim}, but only support 1 or 2") + raise NotImplementedError( 
+ f"Get device mesh ndim={device_mesh.ndim}, but only support 1 or 2" + ) return sharding_strategy @@ -128,26 +139,44 @@ def __init__(self, config: DictConfig, role: str, **kwargs): # build device mesh for FSDP world_size = torch.distributed.get_world_size() # TODO(sgm): support FSDP hybrid shard for larger model - self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=self.config.actor.fsdp_config.fsdp_size) + self.device_mesh = create_device_mesh( + world_size=world_size, fsdp_size=self.config.actor.fsdp_config.fsdp_size + ) # build device mesh for Ulysses Sequence Parallel self.ulysses_device_mesh = None - self.ulysses_sequence_parallel_size = self.config.actor.get("ulysses_sequence_parallel_size", 1) + self.ulysses_sequence_parallel_size = self.config.actor.get( + "ulysses_sequence_parallel_size", 1 + ) dp = world_size // self.ulysses_sequence_parallel_size if self.ulysses_sequence_parallel_size > 1: self.ulysses_device_mesh = init_device_mesh( - device_name, mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"] + device_name, + mesh_shape=(dp, self.ulysses_sequence_parallel_size), + mesh_dim_names=["dp", "sp"], ) - self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh) + self.ulysses_sharding_manager = FSDPUlyssesShardingManager( + self.ulysses_device_mesh + ) self._lora_rank = self.config.model.get("lora_rank", 0) self._is_lora = self._lora_rank > 0 self.role = role - assert self.role in ["actor", "rollout", "ref", "actor_rollout", "actor_rollout_ref"] + assert self.role in [ + "actor", + "rollout", + "ref", + "actor_rollout", + "actor_rollout_ref", + ] self._is_actor = self.role in ["actor", "actor_rollout", "actor_rollout_ref"] - self._is_rollout = self.role in ["rollout", "actor_rollout", "actor_rollout_ref"] + self._is_rollout = self.role in [ + "rollout", + "actor_rollout", + "actor_rollout_ref", + ] self._is_ref = self.role in ["ref", "actor_rollout_ref"] # TODO(haibin.lin): @@ 
-158,22 +187,33 @@ def __init__(self, config: DictConfig, role: str, **kwargs): # The benefit of creating the dataclass config is to perform validation during __post_init__ profiler_config = omega_conf_to_dataclass(config.get("profiler")) DistProfilerExtension.__init__( - self, DistProfiler(rank=self.rank, config=profiler_config, option=self.profile_option) + self, + DistProfiler( + rank=self.rank, config=profiler_config, option=self.profile_option + ), ) self._is_offload_param = False self._is_offload_optimizer = False if self._is_actor: - self._is_offload_param = self.config.actor.fsdp_config.get("param_offload", False) - self._is_offload_optimizer = self.config.actor.fsdp_config.get("optimizer_offload", False) + self._is_offload_param = self.config.actor.fsdp_config.get( + "param_offload", False + ) + self._is_offload_optimizer = self.config.actor.fsdp_config.get( + "optimizer_offload", False + ) elif self._is_ref: # TODO: it seems that manual offload is slowly than FSDP offload - self._is_offload_param = self.config.ref.fsdp_config.get("param_offload", False) + self._is_offload_param = self.config.ref.fsdp_config.get( + "param_offload", False + ) # normalize config if self._is_actor: self.config.actor.ppo_mini_batch_size *= self.config.rollout.n - self.config.actor.ppo_mini_batch_size //= self.device_mesh.size() // self.ulysses_sequence_parallel_size + self.config.actor.ppo_mini_batch_size //= ( + self.device_mesh.size() // self.ulysses_sequence_parallel_size + ) assert self.config.actor.ppo_mini_batch_size > 0, ( f"ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be larger than 0 after " f"normalization" @@ -183,28 +223,47 @@ def __init__(self, config: DictConfig, role: str, **kwargs): self.config.actor.ppo_micro_batch_size //= ( self.device_mesh.size() // self.ulysses_sequence_parallel_size ) - self.config.actor.ppo_micro_batch_size_per_gpu = self.config.actor.ppo_micro_batch_size + self.config.actor.ppo_micro_batch_size_per_gpu = ( + 
self.config.actor.ppo_micro_batch_size + ) if self.config.actor.ppo_micro_batch_size_per_gpu is not None: - assert self.config.actor.ppo_mini_batch_size % self.config.actor.ppo_micro_batch_size_per_gpu == 0, ( + assert ( + self.config.actor.ppo_mini_batch_size + % self.config.actor.ppo_micro_batch_size_per_gpu + == 0 + ), ( f"normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be divisible by " f"ppo_micro_batch_size_per_gpu {self.config.actor.ppo_micro_batch_size_per_gpu}" ) - assert self.config.actor.ppo_mini_batch_size // self.config.actor.ppo_micro_batch_size_per_gpu > 0, ( + assert ( + self.config.actor.ppo_mini_batch_size + // self.config.actor.ppo_micro_batch_size_per_gpu + > 0 + ), ( f"normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be larger than " f"ppo_micro_batch_size_per_gpu {self.config.actor.ppo_micro_batch_size_per_gpu}" ) # normalize rollout config - if self._is_rollout and self.config.rollout.log_prob_micro_batch_size is not None: + if ( + self._is_rollout + and self.config.rollout.log_prob_micro_batch_size is not None + ): self.config.rollout.log_prob_micro_batch_size //= ( self.device_mesh.size() // self.ulysses_sequence_parallel_size ) - self.config.rollout.log_prob_micro_batch_size_per_gpu = self.config.rollout.log_prob_micro_batch_size + self.config.rollout.log_prob_micro_batch_size_per_gpu = ( + self.config.rollout.log_prob_micro_batch_size + ) # normalize ref config if self._is_ref and self.config.ref.log_prob_micro_batch_size is not None: - self.config.ref.log_prob_micro_batch_size //= self.device_mesh.size() // self.ulysses_sequence_parallel_size - self.config.ref.log_prob_micro_batch_size_per_gpu = self.config.ref.log_prob_micro_batch_size + self.config.ref.log_prob_micro_batch_size //= ( + self.device_mesh.size() // self.ulysses_sequence_parallel_size + ) + self.config.ref.log_prob_micro_batch_size_per_gpu = ( + self.config.ref.log_prob_micro_batch_size + ) def 
_build_model_optimizer( self, @@ -222,9 +281,17 @@ def _build_model_optimizer( ): from torch import optim from torch.distributed.fsdp import CPUOffload, MixedPrecision - from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq + from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoModelForVision2Seq, + ) - from verl.utils.model import get_generation_config, print_model_size, update_model_config + from verl.utils.model import ( + get_generation_config, + print_model_size, + update_model_config, + ) from verl.utils.torch_dtypes import PrecisionType assert role in ["actor", "ref"] @@ -251,14 +318,18 @@ def _build_model_optimizer( # override model kwargs actor_model_config = AutoConfig.from_pretrained( - local_path, trust_remote_code=trust_remote_code, attn_implementation="flash_attention_2" + local_path, + trust_remote_code=trust_remote_code, + attn_implementation="flash_attention_2", ) # patch for kimi-vl if getattr(actor_model_config, "model_type", None) == "kimi_vl": actor_model_config.text_config.topk_method = "greedy" - self.generation_config = get_generation_config(local_path, trust_remote_code=trust_remote_code) + self.generation_config = get_generation_config( + local_path, trust_remote_code=trust_remote_code + ) override_config_kwargs = { "bos_token_id": self.tokenizer.bos_token_id, @@ -266,13 +337,16 @@ def _build_model_optimizer( "pad_token_id": self.tokenizer.pad_token_id, } override_config_kwargs.update(override_model_config) - update_model_config(actor_model_config, override_config_kwargs=override_config_kwargs) + update_model_config( + actor_model_config, override_config_kwargs=override_config_kwargs + ) if self.rank == 0: print(f"Model config after override: {actor_model_config}") # NOTE(fix me): tie_word_embedding causes meta_tensor init to hang init_context = get_init_weight_context_manager( - use_meta_tensor=not actor_model_config.tie_word_embeddings, mesh=self.device_mesh + use_meta_tensor=not 
actor_model_config.tie_word_embeddings, + mesh=self.device_mesh, ) with init_context(), warnings.catch_warnings(): @@ -291,13 +365,17 @@ def _build_model_optimizer( # Apply Liger kernel to the model if use_liger is set to True if use_liger: - from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance + from liger_kernel.transformers.monkey_patch import ( + _apply_liger_kernel_to_instance, + ) _apply_liger_kernel_to_instance(model=actor_module) fused_kernel_options = self.config.model.get("fused_kernel_options", None) fused_kernels_backend = ( - fused_kernel_options.get("impl_backend", None) if fused_kernel_options is not None else None + fused_kernel_options.get("impl_backend", None) + if fused_kernel_options is not None + else None ) apply_monkey_patch( @@ -312,7 +390,9 @@ def _build_model_optimizer( actor_module.to(torch_dtype) if enable_gradient_checkpointing: - actor_module.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) + actor_module.gradient_checkpointing_enable( + gradient_checkpointing_kwargs={"use_reentrant": False} + ) if self._is_lora: print("Applying LoRA to actor module") actor_module.enable_input_require_grads() @@ -321,8 +401,12 @@ def _build_model_optimizer( "task_type": TaskType.CAUSAL_LM, "r": self.config.model.lora_rank, "lora_alpha": self.config.model.lora_alpha, - "target_modules": convert_to_regular_types(self.config.model.target_modules), - "exclude_modules": convert_to_regular_types(self.config.model.exclude_modules), + "target_modules": convert_to_regular_types( + self.config.model.target_modules + ), + "exclude_modules": convert_to_regular_types( + self.config.model.exclude_modules + ), "bias": "none", } actor_module = get_peft_model(actor_module, LoraConfig(**lora_config)) @@ -336,15 +420,25 @@ def _build_model_optimizer( # We wrap FSDP for rollout as well mixed_precision_config = fsdp_config.get("mixed_precision", None) if mixed_precision_config is not None: - param_dtype = 
PrecisionType.to_dtype(mixed_precision_config.get("param_dtype", "bf16")) - reduce_dtype = PrecisionType.to_dtype(mixed_precision_config.get("reduce_dtype", "fp32")) - buffer_dtype = PrecisionType.to_dtype(mixed_precision_config.get("buffer_dtype", "fp32")) + param_dtype = PrecisionType.to_dtype( + mixed_precision_config.get("param_dtype", "bf16") + ) + reduce_dtype = PrecisionType.to_dtype( + mixed_precision_config.get("reduce_dtype", "fp32") + ) + buffer_dtype = PrecisionType.to_dtype( + mixed_precision_config.get("buffer_dtype", "fp32") + ) else: param_dtype = torch.bfloat16 reduce_dtype = torch.float32 buffer_dtype = torch.float32 - mixed_precision = MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype) + mixed_precision = MixedPrecision( + param_dtype=param_dtype, + reduce_dtype=reduce_dtype, + buffer_dtype=buffer_dtype, + ) auto_wrap_policy = get_fsdp_wrap_policy( module=actor_module, @@ -378,20 +472,30 @@ def _build_model_optimizer( mixed_precision=mixed_precision, sync_module_states=True, device_mesh=self.device_mesh, - use_orig_params=self.config.actor.fsdp_config.get("use_orig_params", False), - forward_prefetch=self.config.actor.fsdp_config.get("forward_prefetch", False), + use_orig_params=self.config.actor.fsdp_config.get( + "use_orig_params", False + ), + forward_prefetch=self.config.actor.fsdp_config.get( + "forward_prefetch", False + ), ) elif fsdp_strategy == "fsdp2": - assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)" + assert ( + CPUOffloadPolicy is not None + ), "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)" mp_policy = MixedPrecisionPolicy( - param_dtype=param_dtype, reduce_dtype=reduce_dtype, cast_forward_inputs=True + param_dtype=param_dtype, + reduce_dtype=reduce_dtype, + cast_forward_inputs=True, ) if role == "actor" and fsdp_config.offload_policy: cpu_offload = CPUOffloadPolicy(pin_memory=True) self._is_offload_param = 
False self._is_offload_optimizer = False else: - cpu_offload = None if role == "actor" else CPUOffloadPolicy(pin_memory=True) + cpu_offload = ( + None if role == "actor" else CPUOffloadPolicy(pin_memory=True) + ) fsdp_kwargs = { "mesh": fsdp_mesh, @@ -407,13 +511,18 @@ def _build_model_optimizer( raise NotImplementedError(f"not implement {fsdp_strategy}") if enable_activation_offload: - enable_activation_offloading(actor_module_fsdp, fsdp_strategy, enable_gradient_checkpointing) + enable_activation_offloading( + actor_module_fsdp, fsdp_strategy, enable_gradient_checkpointing + ) log_gpu_memory_usage(f"After {role} FSDP init", logger=logger) # TODO: add more optimizer args into config if role == "actor" and optim_config is not None: - from verl.utils.torch_functional import get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup + from verl.utils.torch_functional import ( + get_constant_schedule_with_warmup, + get_cosine_schedule_with_warmup, + ) actor_optimizer = optim.AdamW( actor_module_fsdp.parameters(), @@ -432,7 +541,9 @@ def _build_model_optimizer( num_warmup_steps = int(num_warmup_steps_ratio * total_steps) if self.rank == 0: - print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}") + print( + f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}" + ) if warmup_style == "constant": actor_lr_scheduler = get_constant_schedule_with_warmup( @@ -447,14 +558,21 @@ def _build_model_optimizer( num_cycles=num_cycles, ) else: - raise NotImplementedError(f"Warmup style {warmup_style} is not supported") + raise NotImplementedError( + f"Warmup style {warmup_style} is not supported" + ) log_gpu_memory_usage(f"After {role} optimizer init", logger=logger) else: actor_optimizer = None actor_lr_scheduler = None - return actor_module_fsdp, actor_optimizer, actor_lr_scheduler, actor_model_config + return ( + actor_module_fsdp, + actor_optimizer, + actor_lr_scheduler, + actor_model_config, + ) def _build_rollout(self, 
trust_remote_code=False): from torch.distributed.device_mesh import init_device_mesh @@ -462,9 +580,9 @@ def _build_rollout(self, trust_remote_code=False): # TODO(sgm): support FSDP hybrid shard for larger model infer_tp = self.config.rollout.tensor_model_parallel_size dp = self.world_size // infer_tp - assert self.world_size % infer_tp == 0, ( - f"rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}" - ) + assert ( + self.world_size % infer_tp == 0 + ), f"rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}" rollout_device_mesh = init_device_mesh( device_name, mesh_shape=(dp, infer_tp), mesh_dim_names=["dp", "infer_tp"] ) @@ -473,7 +591,9 @@ def _build_rollout(self, trust_remote_code=False): from verl.workers.rollout import HFRollout from verl.workers.sharding_manager.base import BaseShardingManager - rollout = HFRollout(module=self.actor_module_fsdp, config=self.config.rollout) + rollout = HFRollout( + module=self.actor_module_fsdp, config=self.config.rollout + ) rollout_sharding_manager = BaseShardingManager() # TODO: a sharding manager that do nothing? 
@@ -481,17 +601,29 @@ def _build_rollout(self, trust_remote_code=False): from verl.workers.rollout.vllm_rollout import vLLMRollout from verl.workers.sharding_manager.fsdp_vllm import FSDPVLLMShardingManager - log_gpu_memory_usage(f"Before building {rollout_name} rollout", logger=logger) - local_path = copy_to_local(self.config.model.path, use_shm=self.config.model.get("use_shm", False)) + log_gpu_memory_usage( + f"Before building {rollout_name} rollout", logger=logger + ) + local_path = copy_to_local( + self.config.model.path, use_shm=self.config.model.get("use_shm", False) + ) lora_kwargs = ( - {"lora_kwargs": {"enable_lora": True, "max_loras": 1, "max_lora_rank": self._lora_rank}} + { + "lora_kwargs": { + "enable_lora": True, + "max_loras": 1, + "max_lora_rank": self._lora_rank, + } + } if self._is_lora else {} ) # lora_kwargs = {} from verl.workers.rollout.vllm_rollout import vLLMAsyncRollout - vllm_rollout_cls = vLLMRollout if self.config.rollout.mode == "sync" else vLLMAsyncRollout + vllm_rollout_cls = ( + vLLMRollout if self.config.rollout.mode == "sync" else vLLMAsyncRollout + ) rollout = vllm_rollout_cls( model_path=local_path, config=self.config.rollout, @@ -502,7 +634,9 @@ def _build_rollout(self, trust_remote_code=False): **lora_kwargs, ) - log_gpu_memory_usage(f"After building {rollout_name} rollout", logger=logger) + log_gpu_memory_usage( + f"After building {rollout_name} rollout", logger=logger + ) full_params = torch.distributed.get_world_size() == 1 rollout_sharding_manager = FSDPVLLMShardingManager( module=self.actor_module_fsdp, @@ -527,18 +661,26 @@ def _build_rollout(self, trust_remote_code=False): # For this reason, sharding_manager.__init__ should not import FSDPSGLangShardingManager and # we import it here use the abs path. 
# check: https://github.com/sgl-project/sglang/blob/00f42707eaddfc2c0528e5b1e0094025c640b7a0/python/sglang/srt/layers/quantization/fp8_utils.py#L76 - from verl.workers.sharding_manager.fsdp_sglang import FSDPSGLangShardingManager + from verl.workers.sharding_manager.fsdp_sglang import ( + FSDPSGLangShardingManager, + ) local_path = copy_to_local(self.config.model.path) - log_gpu_memory_usage(f"Before building {rollout_name} rollout", logger=logger) + log_gpu_memory_usage( + f"Before building {rollout_name} rollout", logger=logger + ) rollout = SGLangRollout( actor_module=local_path, config=self.config.rollout, - processing_class=self.processor if self.processor is not None else self.tokenizer, + processing_class=( + self.processor if self.processor is not None else self.tokenizer + ), model_hf_config=self.actor_model_config, trust_remote_code=trust_remote_code, ) - log_gpu_memory_usage(f"After building {rollout_name} rollout", logger=logger) + log_gpu_memory_usage( + f"After building {rollout_name} rollout", logger=logger + ) if torch.distributed.get_world_size() == 1: self.config.rollout.load_format = "dummy_hf" @@ -555,7 +697,9 @@ def _build_rollout(self, trust_remote_code=False): log_gpu_memory_usage("After building sharding manager", logger=logger) else: - raise NotImplementedError(f"Rollout name: {self.config.rollout.name} is not supported") + raise NotImplementedError( + f"Rollout name: {self.config.rollout.name} is not supported" + ) return rollout, rollout_sharding_manager @@ -566,7 +710,9 @@ def init_model(self): # This is used to import external_lib into the huggingface systems import_external_libs(self.config.model.get("external_lib", None)) - override_model_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create())) + override_model_config = OmegaConf.to_container( + self.config.model.get("override_config", OmegaConf.create()) + ) use_remove_padding = self.config.model.get("use_remove_padding", False) use_shm = 
self.config.model.get("use_shm", False) @@ -594,11 +740,15 @@ def init_model(self): override_model_config=override_model_config, use_remove_padding=use_remove_padding, use_fused_kernels=use_fused_kernels, - enable_gradient_checkpointing=self.config.model.get("enable_gradient_checkpointing", False), + enable_gradient_checkpointing=self.config.model.get( + "enable_gradient_checkpointing", False + ), trust_remote_code=self.config.model.get("trust_remote_code", False), use_liger=self.config.model.get("use_liger", False), role="actor", - enable_activation_offload=self.config.model.get("enable_activation_offload", False), + enable_activation_offload=self.config.model.get( + "enable_activation_offload", False + ), ) # get the original unwrapped module @@ -607,11 +757,15 @@ def init_model(self): if self._is_offload_param: offload_fsdp_model_to_cpu(self.actor_module_fsdp) - log_gpu_memory_usage("After offload actor model during init", logger=logger) + log_gpu_memory_usage( + "After offload actor model during init", logger=logger + ) if self._is_offload_optimizer: offload_fsdp_optimizer(optimizer=self.actor_optimizer) - log_gpu_memory_usage("After offload actor optimizer during init", logger=logger) + log_gpu_memory_usage( + "After offload actor optimizer during init", logger=logger + ) if self._is_actor: OmegaConf.set_struct(self.config.actor, True) @@ -619,7 +773,9 @@ def init_model(self): self.config.actor.use_remove_padding = use_remove_padding self.config.actor.use_fused_kernels = use_fused_kernels self.actor = DataParallelPPOActor( - config=self.config.actor, actor_module=self.actor_module_fsdp, actor_optimizer=self.actor_optimizer + config=self.config.actor, + actor_module=self.actor_module_fsdp, + actor_optimizer=self.actor_optimizer, ) if self._is_rollout: @@ -644,7 +800,9 @@ def init_model(self): with open_dict(self.config.ref): self.config.ref.use_remove_padding = use_remove_padding self.config.ref.use_fused_kernels = use_fused_kernels - self.ref_policy = 
DataParallelPPOActor(config=self.config.ref, actor_module=self.ref_module_fsdp) + self.ref_policy = DataParallelPPOActor( + config=self.config.ref, actor_module=self.ref_module_fsdp + ) if self._is_actor: self.flops_counter = FlopsCounter(self.actor_model_config) @@ -652,7 +810,9 @@ def init_model(self): model=self.actor_module_fsdp, optimizer=self.actor.actor_optimizer, lr_scheduler=self.actor_lr_scheduler, - processing_class=self.processor if self.processor is not None else self.tokenizer, + processing_class=( + self.processor if self.processor is not None else self.tokenizer + ), checkpoint_config=self.config.actor.checkpoint, ) @@ -660,12 +820,16 @@ def init_model(self): # If ActorRolloutRefWorker is initialized as a standalone rollout, # create a checkpoint manager for FSDP model to allow loading FSDP checkpoints for rollout. - checkpoint_contents = OmegaConf.create({"load_contents": ["model"], "save_contents": []}) + checkpoint_contents = OmegaConf.create( + {"load_contents": ["model"], "save_contents": []} + ) self.checkpoint_manager = FSDPCheckpointManager( model=self.actor_module_fsdp, optimizer=None, lr_scheduler=None, - processing_class=self.processor if self.processor is not None else self.tokenizer, + processing_class=( + self.processor if self.processor is not None else self.tokenizer + ), checkpoint_config=checkpoint_contents, ) @@ -679,7 +843,9 @@ def update_actor(self, data: DataProto): if self._is_offload_param: load_fsdp_model_to_gpu(self.actor_module_fsdp) if self._is_offload_optimizer: - load_fsdp_optimizer(optimizer=self.actor_optimizer, device_id=get_device_id()) + load_fsdp_optimizer( + optimizer=self.actor_optimizer, device_id=get_device_id() + ) with self.ulysses_sharding_manager: data = self.ulysses_sharding_manager.preprocess_data(data=data) @@ -688,13 +854,24 @@ def update_actor(self, data: DataProto): metrics = self.actor.update_policy(data=data) delta_time = timer.last global_num_tokens = data.meta_info["global_token_num"] - 
estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time) + estimated_flops, promised_flops = self.flops_counter.estimate_flops( + global_num_tokens, delta_time + ) metrics["perf/mfu/actor"] = ( - estimated_flops * self.config.actor.ppo_epochs / promised_flops / self.world_size + estimated_flops + * self.config.actor.ppo_epochs + / promised_flops + / self.world_size + ) + metrics["perf/max_memory_allocated_gb"] = ( + get_torch_device().max_memory_allocated() / (1024**3) + ) + metrics["perf/max_memory_reserved_gb"] = ( + get_torch_device().max_memory_reserved() / (1024**3) + ) + metrics["perf/cpu_memory_used_gb"] = psutil.virtual_memory().used / ( + 1024**3 ) - metrics["perf/max_memory_allocated_gb"] = get_torch_device().max_memory_allocated() / (1024**3) - metrics["perf/max_memory_reserved_gb"] = get_torch_device().max_memory_reserved() / (1024**3) - metrics["perf/cpu_memory_used_gb"] = psutil.virtual_memory().used / (1024**3) lr = self.actor_lr_scheduler.get_last_lr()[0] metrics["actor/lr"] = lr @@ -708,10 +885,14 @@ def update_actor(self, data: DataProto): if self._is_offload_param: offload_fsdp_model_to_cpu(self.actor_module_fsdp) - log_gpu_memory_usage("After offload actor model during update_actor", logger=logger) + log_gpu_memory_usage( + "After offload actor model during update_actor", logger=logger + ) if self._is_offload_optimizer: offload_fsdp_optimizer(optimizer=self.actor_optimizer) - log_gpu_memory_usage("After offload actor optimizer during update_actor", logger=logger) + log_gpu_memory_usage( + "After offload actor optimizer during update_actor", logger=logger + ) return output @@ -724,17 +905,23 @@ def generate_sequences(self, prompts: DataProto): assert self._is_rollout meta_info = { - "eos_token_id": self.generation_config.eos_token_id - if self.generation_config is not None - else self.tokenizer.eos_token_id, - "pad_token_id": self.generation_config.pad_token_id - if self.generation_config is not None - 
else self.tokenizer.pad_token_id, + "eos_token_id": ( + self.generation_config.eos_token_id + if self.generation_config is not None + else self.tokenizer.eos_token_id + ), + "pad_token_id": ( + self.generation_config.pad_token_id + if self.generation_config is not None + else self.tokenizer.pad_token_id + ), } prompts.meta_info.update(meta_info) timing_generate = {} with self.rollout_sharding_manager: - log_gpu_memory_usage("After entering rollout sharding manager", logger=logger) + log_gpu_memory_usage( + "After entering rollout sharding manager", logger=logger + ) prompts = self.rollout_sharding_manager.preprocess_data(prompts) with simple_timer("generate_sequences", timing_generate): @@ -768,18 +955,26 @@ def compute_log_prob(self, data: DataProto): from contextlib import nullcontext is_lora = data.meta_info.pop("is_lora", False) - adapter_ctx = self.actor.actor_module.disable_adapter() if is_lora else nullcontext() + adapter_ctx = ( + self.actor.actor_module.disable_adapter() if is_lora else nullcontext() + ) data = data.to(get_device_id()) # we should always recompute old_log_probs when it is HybridEngine - data.meta_info["micro_batch_size"] = self.config.rollout.log_prob_micro_batch_size_per_gpu - data.meta_info["max_token_len"] = self.config.rollout.log_prob_max_token_len_per_gpu + data.meta_info["micro_batch_size"] = ( + self.config.rollout.log_prob_micro_batch_size_per_gpu + ) + data.meta_info["max_token_len"] = ( + self.config.rollout.log_prob_max_token_len_per_gpu + ) data.meta_info["use_dynamic_bsz"] = self.config.rollout.log_prob_use_dynamic_bsz data.meta_info["temperature"] = self.config.rollout.temperature # perform recompute log_prob with self.ulysses_sharding_manager: data = self.ulysses_sharding_manager.preprocess_data(data) with adapter_ctx: - output, entropys = self.actor.compute_log_prob(data=data, calculate_entropy=True) + output, entropys = self.actor.compute_log_prob( + data=data, calculate_entropy=True + ) output = DataProto.from_dict( 
tensors={"old_log_probs": output, "entropys": entropys}, meta_info={"temperature": self.config.rollout.temperature}, @@ -795,7 +990,9 @@ def compute_log_prob(self, data: DataProto): if self._is_offload_param: offload_fsdp_model_to_cpu(self.actor_module_fsdp) - log_gpu_memory_usage("After offload actor model during compute_log_prob", logger=logger) + log_gpu_memory_usage( + "After offload actor model during compute_log_prob", logger=logger + ) return output @@ -807,7 +1004,9 @@ def compute_ref_log_prob(self, data: DataProto): data.meta_info["is_lora"] = True data = self.compute_log_prob(data) # this old_log_probs is in fact ref_log_prob - data = DataProto.from_dict(tensors={"ref_log_prob": data.batch["old_log_probs"]}) + data = DataProto.from_dict( + tensors={"ref_log_prob": data.batch["old_log_probs"]} + ) return data assert self._is_ref # else: @@ -822,7 +1021,9 @@ def compute_ref_log_prob(self, data: DataProto): data.meta_info["use_dynamic_bsz"] = self.config.ref.log_prob_use_dynamic_bsz with self.ulysses_sharding_manager: data = self.ulysses_sharding_manager.preprocess_data(data) - output, _ = self.ref_policy.compute_log_prob(data=data, calculate_entropy=False) + output, _ = self.ref_policy.compute_log_prob( + data=data, calculate_entropy=False + ) output = DataProto.from_dict(tensors={"ref_log_prob": output}) output = self.ulysses_sharding_manager.postprocess_data(output) @@ -836,7 +1037,9 @@ def compute_ref_log_prob(self, data: DataProto): return output @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None): + def save_checkpoint( + self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None + ): from verl.utils.logger import log_with_rank # only support save and load ckpt for actor @@ -846,11 +1049,16 @@ def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to load_fsdp_model_to_gpu(self.actor_module_fsdp) 
self.checkpoint_manager.save_checkpoint( - local_path=local_path, hdfs_path=hdfs_path, global_step=global_step, max_ckpt_to_keep=max_ckpt_to_keep + local_path=local_path, + hdfs_path=hdfs_path, + global_step=global_step, + max_ckpt_to_keep=max_ckpt_to_keep, ) dist.barrier() - if self._is_lora and hasattr(getattr(self, "actor_module", self.actor_module_fsdp), "peft_config"): + if self._is_lora and hasattr( + getattr(self, "actor_module", self.actor_module_fsdp), "peft_config" + ): lora_save_path = os.path.join(local_path, "lora_adapter") peft_model = getattr(self, "actor_module", self.actor_module_fsdp) peft_config = {} @@ -862,15 +1070,27 @@ def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to peft_config["target_modules"] = list(peft_config["target_modules"]) try: if fsdp_version(self.actor_module_fsdp) > 0: - self.actor_module_fsdp = self.actor_module_fsdp.to(get_device_name()) + self.actor_module_fsdp = self.actor_module_fsdp.to( + get_device_name() + ) lora_params = layered_summon_lora_params(self.actor_module_fsdp) if dist.get_rank() == 0: - save_file(lora_params, os.path.join(lora_save_path, "adapter_model.safetensors")) - with open(os.path.join(lora_save_path, "adapter_config.json"), "w", encoding="utf-8") as f: + save_file( + lora_params, + os.path.join(lora_save_path, "adapter_model.safetensors"), + ) + with open( + os.path.join(lora_save_path, "adapter_config.json"), + "w", + encoding="utf-8", + ) as f: json.dump(peft_config, f, ensure_ascii=False, indent=4) except Exception as e: log_with_rank( - f"Save LoRA Adapter Error ({e})", rank=dist.get_rank(), logger=logger, log_only_rank_0=True + f"Save LoRA Adapter Error ({e})", + rank=dist.get_rank(), + logger=logger, + log_only_rank_0=True, ) dist.barrier() @@ -895,7 +1115,9 @@ def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=False load_fsdp_model_to_gpu(self.actor_module_fsdp) self.checkpoint_manager.load_checkpoint( - local_path=local_path, 
hdfs_path=hdfs_path, del_local_after_load=del_local_after_load + local_path=local_path, + hdfs_path=hdfs_path, + del_local_after_load=del_local_after_load, ) if self._is_offload_param: @@ -919,13 +1141,17 @@ class CriticWorker(Worker, DistProfilerExtension): def __init__(self, config): Worker.__init__(self) DistProfilerExtension.__init__( - self, DistProfiler(rank=self.rank, config=omega_conf_to_dataclass(config.get("profiler"))) + self, + DistProfiler( + rank=self.rank, config=omega_conf_to_dataclass(config.get("profiler")) + ), ) import torch.distributed if not torch.distributed.is_initialized(): torch.distributed.init_process_group( - backend=get_nccl_backend(), init_method=os.environ.get("DIST_INIT_METHOD", None) + backend=get_nccl_backend(), + init_method=os.environ.get("DIST_INIT_METHOD", None), ) self.config = config @@ -934,17 +1160,25 @@ def __init__(self, config): from torch.distributed.device_mesh import init_device_mesh fsdp_size = self.config.model.fsdp_config.fsdp_size - self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=fsdp_size) + self.device_mesh = create_device_mesh( + world_size=world_size, fsdp_size=fsdp_size + ) self.ulysses_device_mesh = None - self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1) + self.ulysses_sequence_parallel_size = self.config.get( + "ulysses_sequence_parallel_size", 1 + ) dp = world_size // self.ulysses_sequence_parallel_size if self.ulysses_sequence_parallel_size > 1: self.ulysses_device_mesh = init_device_mesh( - device_name, mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"] + device_name, + mesh_shape=(dp, self.ulysses_sequence_parallel_size), + mesh_dim_names=["dp", "sp"], ) - self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh) + self.ulysses_sharding_manager = FSDPUlyssesShardingManager( + self.ulysses_device_mesh + ) # set FSDP offload params self._is_offload_param = 
self.config.model.fsdp_config.param_offload @@ -952,23 +1186,37 @@ def __init__(self, config): # normalize config self.config.ppo_mini_batch_size *= self.config.rollout_n - self.config.ppo_mini_batch_size //= torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size + self.config.ppo_mini_batch_size //= ( + torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size + ) if self.config.ppo_micro_batch_size is not None: self.config.ppo_micro_batch_size //= ( - torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size + torch.distributed.get_world_size() + // self.ulysses_sequence_parallel_size ) self.config.forward_micro_batch_size //= ( - torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size + torch.distributed.get_world_size() + // self.ulysses_sequence_parallel_size ) self.config.ppo_micro_batch_size_per_gpu = self.config.ppo_micro_batch_size - self.config.forward_micro_batch_size_per_gpu = self.config.forward_micro_batch_size + self.config.forward_micro_batch_size_per_gpu = ( + self.config.forward_micro_batch_size + ) if self.config.ppo_micro_batch_size_per_gpu is not None: - assert self.config.ppo_mini_batch_size % self.config.ppo_micro_batch_size_per_gpu == 0, ( + assert ( + self.config.ppo_mini_batch_size + % self.config.ppo_micro_batch_size_per_gpu + == 0 + ), ( f"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be divisible by " f"ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}" ) - assert self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu > 0, ( + assert ( + self.config.ppo_mini_batch_size + // self.config.ppo_micro_batch_size_per_gpu + > 0 + ), ( f"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be larger than " f"ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}" ) @@ -988,8 +1236,14 @@ def _build_critic_model_optimizer(self, config): # using random initialized model from 
any architecture. May not be the same as Actor. tokenizer_path = copy_to_local(config.model.tokenizer_path, use_shm=use_shm) - self.tokenizer = hf_tokenizer(tokenizer_path, trust_remote_code=config.model.get("trust_remote_code", False)) - self.processor = hf_processor(tokenizer_path, trust_remote_code=config.model.get("trust_remote_code", False)) + self.tokenizer = hf_tokenizer( + tokenizer_path, + trust_remote_code=config.model.get("trust_remote_code", False), + ) + self.processor = hf_processor( + tokenizer_path, + trust_remote_code=config.model.get("trust_remote_code", False), + ) if self.config.model.get("custom_chat_template", None) is not None: if self.processor is not None: @@ -997,7 +1251,9 @@ def _build_critic_model_optimizer(self, config): else: self.tokenizer.chat_template = self.config.model.custom_chat_template - override_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create())) + override_config = OmegaConf.to_container( + self.config.model.get("override_config", OmegaConf.create()) + ) override_config_kwargs = { "bos_token_id": self.tokenizer.bos_token_id, "eos_token_id": self.tokenizer.eos_token_id, @@ -1023,7 +1279,8 @@ def _build_critic_model_optimizer(self, config): critic_model_config.text_config.topk_method = "greedy" init_context = get_init_weight_context_manager( - use_meta_tensor=not critic_model_config.tie_word_embeddings, mesh=self.device_mesh + use_meta_tensor=not critic_model_config.tie_word_embeddings, + mesh=self.device_mesh, ) with init_context(), warnings.catch_warnings(): @@ -1051,7 +1308,9 @@ def _build_critic_model_optimizer(self, config): critic_module.to(torch_dtype) if config.model.get("enable_gradient_checkpointing", False): - critic_module.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) + critic_module.gradient_checkpointing_enable( + gradient_checkpointing_kwargs={"use_reentrant": False} + ) if self._is_lora: print("Applying LoRA to critic module") @@ 
-1061,7 +1320,9 @@ def _build_critic_model_optimizer(self, config): "task_type": TaskType.CAUSAL_LM, "r": self.config.model.lora_rank, "lora_alpha": self.config.model.lora_alpha, - "target_modules": convert_to_regular_types(self.config.model.target_modules), + "target_modules": convert_to_regular_types( + self.config.model.target_modules + ), "bias": "none", } critic_module = get_peft_model(critic_module, LoraConfig(**lora_config)) @@ -1074,15 +1335,25 @@ def _build_critic_model_optimizer(self, config): fsdp_config = self.config.model.fsdp_config mixed_precision_config = fsdp_config.get("mixed_precision", None) if mixed_precision_config is not None: - param_dtype = PrecisionType.to_dtype(mixed_precision_config.get("param_dtype", "bf16")) - reduce_dtype = PrecisionType.to_dtype(mixed_precision_config.get("reduce_dtype", "fp32")) - buffer_dtype = PrecisionType.to_dtype(mixed_precision_config.get("buffer_dtype", "fp32")) + param_dtype = PrecisionType.to_dtype( + mixed_precision_config.get("param_dtype", "bf16") + ) + reduce_dtype = PrecisionType.to_dtype( + mixed_precision_config.get("reduce_dtype", "fp32") + ) + buffer_dtype = PrecisionType.to_dtype( + mixed_precision_config.get("buffer_dtype", "fp32") + ) else: param_dtype = torch.bfloat16 reduce_dtype = torch.float32 buffer_dtype = torch.float32 - mixed_precision = MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype) + mixed_precision = MixedPrecision( + param_dtype=param_dtype, + reduce_dtype=reduce_dtype, + buffer_dtype=buffer_dtype, + ) auto_wrap_policy = get_fsdp_wrap_policy( module=critic_module, @@ -1111,9 +1382,13 @@ def _build_critic_model_optimizer(self, config): cpu_offload=None, ) elif config.strategy == "fsdp2": - assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)" + assert ( + CPUOffloadPolicy is not None + ), "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)" mp_policy = 
MixedPrecisionPolicy( - param_dtype=param_dtype, reduce_dtype=reduce_dtype, cast_forward_inputs=True + param_dtype=param_dtype, + reduce_dtype=reduce_dtype, + cast_forward_inputs=True, ) offload_policy = None if fsdp_config.offload_policy: @@ -1129,13 +1404,19 @@ def _build_critic_model_optimizer(self, config): } full_state = critic_module.state_dict() apply_fsdp2(critic_module, fsdp_kwargs, fsdp_config) - fsdp2_load_full_state_dict(critic_module, full_state, fsdp_mesh, offload_policy) + fsdp2_load_full_state_dict( + critic_module, full_state, fsdp_mesh, offload_policy + ) else: raise NotImplementedError(f"Unknown strategy {config.strategy}") if config.model.get("enable_activation_offload", False): - enable_gradient_checkpointing = config.model.get("enable_gradient_checkpointing", False) - enable_activation_offloading(critic_module, config.strategy, enable_gradient_checkpointing) + enable_gradient_checkpointing = config.model.get( + "enable_gradient_checkpointing", False + ) + enable_activation_offloading( + critic_module, config.strategy, enable_gradient_checkpointing + ) log_gpu_memory_usage("After critic FSDP", logger=None) @@ -1156,7 +1437,10 @@ def _build_critic_model_optimizer(self, config): if self.rank == 0: print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}") - from verl.utils.torch_functional import get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup + from verl.utils.torch_functional import ( + get_constant_schedule_with_warmup, + get_cosine_schedule_with_warmup, + ) if warmup_style == "constant": critic_lr_scheduler = get_constant_schedule_with_warmup( @@ -1164,7 +1448,9 @@ def _build_critic_model_optimizer(self, config): ) elif warmup_style == "cosine": critic_lr_scheduler = get_cosine_schedule_with_warmup( - optimizer=critic_optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=total_steps + optimizer=critic_optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=total_steps, ) else: raise 
NotImplementedError(f"Warmup style {warmup_style} is not supported") @@ -1178,19 +1464,25 @@ def init_model(self): from verl.workers.critic import DataParallelPPOCritic - self.critic_module, self.critic_optimizer, self.critic_lr_scheduler = self._build_critic_model_optimizer( - self.config + self.critic_module, self.critic_optimizer, self.critic_lr_scheduler = ( + self._build_critic_model_optimizer(self.config) ) if self._is_offload_param: offload_fsdp_model_to_cpu(self.critic_module) - log_gpu_memory_usage("After offload critic model during init", logger=logger) + log_gpu_memory_usage( + "After offload critic model during init", logger=logger + ) if self._is_offload_optimizer: offload_fsdp_optimizer(optimizer=self.critic_optimizer) - log_gpu_memory_usage("After offload critic optimizer during init", logger=logger) + log_gpu_memory_usage( + "After offload critic optimizer during init", logger=logger + ) self.critic = DataParallelPPOCritic( - config=self.config, critic_module=self.critic_module, critic_optimizer=self.critic_optimizer + config=self.config, + critic_module=self.critic_module, + critic_optimizer=self.critic_optimizer, ) self.flops_counter = FlopsCounter(self.critic_model_config) @@ -1198,7 +1490,9 @@ def init_model(self): model=self.critic_module, optimizer=self.critic_optimizer, lr_scheduler=self.critic_lr_scheduler, - processing_class=self.processor if self.processor is not None else self.tokenizer, + processing_class=( + self.processor if self.processor is not None else self.tokenizer + ), checkpoint_config=self.config.checkpoint, ) @@ -1234,7 +1528,9 @@ def update_critic(self, data: DataProto): if self._is_offload_param: load_fsdp_model_to_gpu(self.critic_module) if self._is_offload_optimizer: - load_fsdp_optimizer(optimizer=self.critic_optimizer, device_id=get_device_id()) + load_fsdp_optimizer( + optimizer=self.critic_optimizer, device_id=get_device_id() + ) # perform forward computation with self.ulysses_sharding_manager: @@ -1245,8 +1541,15 @@ 
def update_critic(self, data: DataProto): delta_time = timer.last global_num_tokens = data.meta_info["global_token_num"] - estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time) - metrics["perf/mfu/critic"] = estimated_flops * self.config.ppo_epochs / promised_flops / self.world_size + estimated_flops, promised_flops = self.flops_counter.estimate_flops( + global_num_tokens, delta_time + ) + metrics["perf/mfu/critic"] = ( + estimated_flops + * self.config.ppo_epochs + / promised_flops + / self.world_size + ) lr = self.critic_lr_scheduler.get_last_lr()[0] metrics["critic/lr"] = lr @@ -1264,14 +1567,19 @@ def update_critic(self, data: DataProto): return output @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None): + def save_checkpoint( + self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None + ): import torch if self._is_offload_param: load_fsdp_model_to_gpu(self.critic_module) self.checkpoint_manager.save_checkpoint( - local_path=local_path, hdfs_path=hdfs_path, global_step=global_step, max_ckpt_to_keep=max_ckpt_to_keep + local_path=local_path, + hdfs_path=hdfs_path, + global_step=global_step, + max_ckpt_to_keep=max_ckpt_to_keep, ) torch.distributed.barrier() @@ -1286,7 +1594,9 @@ def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=True) load_fsdp_model_to_gpu(self.critic_module) self.checkpoint_manager.load_checkpoint( - local_path=local_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load + local_path=local_path, + hdfs_path=hdfs_path, + del_local_after_load=del_local_after_load, ) torch.distributed.barrier() @@ -1306,14 +1616,18 @@ class RewardModelWorker(Worker, DistProfilerExtension): def __init__(self, config): Worker.__init__(self) DistProfilerExtension.__init__( - self, DistProfiler(rank=self.rank, config=omega_conf_to_dataclass(config.get("profiler"))) + self, + DistProfiler( 
+ rank=self.rank, config=omega_conf_to_dataclass(config.get("profiler")) + ), ) import torch.distributed if not torch.distributed.is_initialized(): torch.distributed.init_process_group( - backend=get_nccl_backend(), init_method=os.environ.get("DIST_INIT_METHOD", None) + backend=get_nccl_backend(), + init_method=os.environ.get("DIST_INIT_METHOD", None), ) self.config = config @@ -1322,17 +1636,25 @@ def __init__(self, config): from torch.distributed.device_mesh import init_device_mesh fsdp_size = self.config.model.fsdp_config.fsdp_size - self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=fsdp_size) + self.device_mesh = create_device_mesh( + world_size=world_size, fsdp_size=fsdp_size + ) self.ulysses_device_mesh = None - self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1) + self.ulysses_sequence_parallel_size = self.config.get( + "ulysses_sequence_parallel_size", 1 + ) dp = world_size // self.ulysses_sequence_parallel_size if self.ulysses_sequence_parallel_size > 1: self.ulysses_device_mesh = init_device_mesh( - device_name, mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"] + device_name, + mesh_shape=(dp, self.ulysses_sequence_parallel_size), + mesh_dim_names=["dp", "sp"], ) - self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh) + self.ulysses_sharding_manager = FSDPUlyssesShardingManager( + self.ulysses_device_mesh + ) self.use_remove_padding = self.config.model.get("use_remove_padding", False) @@ -1354,14 +1676,22 @@ def _build_model(self, config): self._do_switch_chat_template = False else: self._do_switch_chat_template = True - input_tokenizer_local_path = copy_to_local(config.model.input_tokenizer, use_shm=use_shm) + input_tokenizer_local_path = copy_to_local( + config.model.input_tokenizer, use_shm=use_shm + ) self.input_tokenizer = hf_tokenizer( - input_tokenizer_local_path, trust_remote_code=config.model.get("trust_remote_code", False) + 
input_tokenizer_local_path, + trust_remote_code=config.model.get("trust_remote_code", False), + ) + self.tokenizer = hf_tokenizer( + local_path, + trust_remote_code=config.model.get("trust_remote_code", False), ) - self.tokenizer = hf_tokenizer(local_path, trust_remote_code=config.model.get("trust_remote_code", False)) trust_remote_code = config.model.get("trust_remote_code", False) - model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code) + model_config = AutoConfig.from_pretrained( + local_path, trust_remote_code=trust_remote_code + ) model_config.num_labels = 1 # note that we have to create model in fp32. Otherwise, the optimizer is in bf16, which is incorrect @@ -1388,7 +1718,9 @@ def _build_model(self, config): reward_module.to(torch.bfloat16) - auto_wrap_policy = get_fsdp_wrap_policy(module=reward_module, config=self.config.model.fsdp_config) + auto_wrap_policy = get_fsdp_wrap_policy( + module=reward_module, config=self.config.model.fsdp_config + ) fsdp_mesh = self.device_mesh sharding_strategy = get_sharding_strategy(fsdp_mesh) @@ -1407,7 +1739,9 @@ def _build_model(self, config): device_mesh=self.device_mesh, ) elif config.strategy == "fsdp2": - assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)" + assert ( + CPUOffloadPolicy is not None + ), "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)" cpu_offload = CPUOffloadPolicy(pin_memory=True) fsdp_kwargs = { "mesh": fsdp_mesh, @@ -1416,7 +1750,9 @@ def _build_model(self, config): } full_state = reward_module.state_dict() apply_fsdp2(reward_module, fsdp_kwargs, config.model.fsdp_config) - fsdp2_load_full_state_dict(reward_module, full_state, fsdp_mesh, cpu_offload) + fsdp2_load_full_state_dict( + reward_module, full_state, fsdp_mesh, cpu_offload + ) else: raise NotImplementedError(f"Unknown strategy: {config.strategy}") return reward_module @@ -1429,7 +1765,12 @@ def init_model(self): def 
_forward_micro_batch(self, micro_batch): if is_cuda_available: - from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input + from flash_attn.bert_padding import ( + index_first_axis, + pad_input, + rearrange, + unpad_input, + ) elif is_npu_available: from transformers.integrations.npu_flash_attention import ( index_first_axis, @@ -1438,15 +1779,22 @@ def _forward_micro_batch(self, micro_batch): unpad_input, ) - from verl.utils.ulysses import gather_outpus_and_unpad, ulysses_pad_and_slice_inputs + from verl.utils.ulysses import ( + gather_outpus_and_unpad, + ulysses_pad_and_slice_inputs, + ) - with torch.no_grad(), torch.autocast(device_type=device_name, dtype=torch.bfloat16): + with torch.no_grad(), torch.autocast( + device_type=device_name, dtype=torch.bfloat16 + ): input_ids = micro_batch["input_ids"] batch_size, seqlen = input_ids.shape attention_mask = micro_batch["attention_mask"] position_ids = micro_batch["position_ids"] if position_ids.dim() == 3: # qwen2vl mrope - position_ids = position_ids.transpose(0, 1) # (bsz, 3, seqlen) -> (3, bsz, seqlen) + position_ids = position_ids.transpose( + 0, 1 + ) # (bsz, 3, seqlen) -> (3, bsz, seqlen) if self.use_remove_padding: input_ids_rmpad, indices, *_ = unpad_input( @@ -1457,24 +1805,34 @@ def _forward_micro_batch(self, micro_batch): # unpad the position_ids to align the rotary if position_ids.dim() == 3: position_ids_rmpad = ( - index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."), indices) + index_first_axis( + rearrange(position_ids, "c b s ... -> (b s) c ..."), indices + ) .transpose(0, 1) .unsqueeze(1) ) # (3, bsz, seqlen) -> (3, 1, bsz * seqlen) else: position_ids_rmpad = index_first_axis( - rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices + rearrange(position_ids.unsqueeze(-1), "b s ... 
-> (b s) ..."), + indices, ).transpose(0, 1) # pad and slice the inputs if sp > 1 if self.ulysses_sequence_parallel_size > 1: - input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs( - input_ids_rmpad, position_ids_rmpad, sp_size=self.ulysses_sequence_parallel_size + input_ids_rmpad, position_ids_rmpad, pad_size = ( + ulysses_pad_and_slice_inputs( + input_ids_rmpad, + position_ids_rmpad, + sp_size=self.ulysses_sequence_parallel_size, + ) ) # only pass input_ids and position_ids to enable flash_attn_varlen output = self.reward_module( - input_ids=input_ids_rmpad, attention_mask=None, position_ids=position_ids_rmpad, use_cache=False + input_ids=input_ids_rmpad, + attention_mask=None, + position_ids=position_ids_rmpad, + use_cache=False, ) reward_rmpad = output.logits reward_rmpad = reward_rmpad.squeeze(0) # (total_nnz) @@ -1486,10 +1844,15 @@ def _forward_micro_batch(self, micro_batch): ) # pad it back - rm_score = pad_input(reward_rmpad, indices=indices, batch=batch_size, seqlen=seqlen).squeeze(-1) + rm_score = pad_input( + reward_rmpad, indices=indices, batch=batch_size, seqlen=seqlen + ).squeeze(-1) else: output = self.reward_module( - input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + use_cache=False, ) rm_score = output.logits # (batch_size, seq_len, 1) rm_score = rm_score.squeeze(-1) @@ -1508,7 +1871,9 @@ def _expand_to_token_level(self, data: DataProto, scores: torch.Tensor): if position_ids.dim() == 3: # qwen2vl mrope [bs, 3, seq_len] position_ids = position_ids[:, 0, :] eos_mask_idx = torch.argmax(position_ids * attention_mask, dim=-1) # (bsz,) - token_level_scores = torch.zeros_like(attention_mask, dtype=scores.dtype) # (bsz, seqlen) + token_level_scores = torch.zeros_like( + attention_mask, dtype=scores.dtype + ) # (bsz, seqlen) token_level_scores[torch.arange(batch_size), eos_mask_idx] = scores # select 
the response part @@ -1535,7 +1900,9 @@ def _switch_chat_template(self, data: DataProto): # extract response response_ids = data.batch["responses"][i] response_length = response_ids.shape[-1] - valid_response_length = data.batch["attention_mask"][i][-response_length:].sum() + valid_response_length = data.batch["attention_mask"][i][ + -response_length: + ].sum() valid_response_ids = response_ids[:valid_response_length] # decode @@ -1557,7 +1924,9 @@ def _switch_chat_template(self, data: DataProto): if max_length is None: max_length = src_max_length - model_inputs = target_tokenizer(prompt_with_chat_template, return_tensors="pt", add_special_tokens=False) + model_inputs = target_tokenizer( + prompt_with_chat_template, return_tensors="pt", add_special_tokens=False + ) input_ids, attention_mask = verl_F.postprocess_data( input_ids=model_inputs["input_ids"], attention_mask=model_inputs["attention_mask"], @@ -1575,7 +1944,11 @@ def _switch_chat_template(self, data: DataProto): rm_position_ids = compute_position_id_with_mask(rm_attention_mask) - rm_inputs = {"input_ids": rm_input_ids, "attention_mask": rm_attention_mask, "position_ids": rm_position_ids} + rm_inputs = { + "input_ids": rm_input_ids, + "attention_mask": rm_attention_mask, + "position_ids": rm_position_ids, + } return DataProto.from_dict(rm_inputs) @@ -1611,10 +1984,17 @@ def compute_rm_score(self, data: DataProto): use_dynamic_bsz = self.config.use_dynamic_bsz if use_dynamic_bsz: - max_token_len = self.config.forward_max_token_len_per_gpu * self.ulysses_sequence_parallel_size - micro_batches, indices = rearrange_micro_batches(batch=rm_data.batch, max_token_len=max_token_len) + max_token_len = ( + self.config.forward_max_token_len_per_gpu + * self.ulysses_sequence_parallel_size + ) + micro_batches, indices = rearrange_micro_batches( + batch=rm_data.batch, max_token_len=max_token_len + ) else: - micro_batches = rm_data.batch.split(self.config.micro_batch_size_per_gpu) + micro_batches = rm_data.batch.split( + 
self.config.micro_batch_size_per_gpu + ) output = [] for micro_batch in micro_batches: rm_score = self._forward_micro_batch(micro_batch) @@ -1623,8 +2003,12 @@ def compute_rm_score(self, data: DataProto): if use_dynamic_bsz: indices = list(itertools.chain.from_iterable(indices)) - assert len(indices) == scores.size(0), f"{len(indices)} vs. {scores.size()}" - revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long) + assert len(indices) == scores.size( + 0 + ), f"{len(indices)} vs. {scores.size()}" + revert_indices = torch.tensor( + get_reverse_idx(indices), dtype=torch.long + ) scores = scores[revert_indices] token_level_scores = self._expand_to_token_level(data, scores) @@ -1660,7 +2044,9 @@ def _build_rollout(self, trust_remote_code=False): @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) def generate_sequences(self, prompts: DataProto): - raise NotImplementedError("AsyncActorRolloutRefWorker does not support generate_sequences") + raise NotImplementedError( + "AsyncActorRolloutRefWorker does not support generate_sequences" + ) # ============================ vLLM related ============================ @@ -1681,7 +2067,9 @@ async def chat_completion(self, json_request): return ret @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD, blocking=False) - async def generate(self, prompt_ids: list[int], sampling_params: dict[str, Any], request_id: str) -> list[int]: + async def generate( + self, prompt_ids: list[int], sampling_params: dict[str, Any], request_id: str + ) -> list[int]: ret = await self.rollout.generate(prompt_ids, sampling_params, request_id) return ret diff --git a/Agent0/executor_train/verl/verl/workers/megatron_workers.py b/Agent0/executor_train/verl/verl/workers/megatron_workers.py index e761f0e..2ad10af 100644 --- a/Agent0/executor_train/verl/verl/workers/megatron_workers.py +++ b/Agent0/executor_train/verl/verl/workers/megatron_workers.py @@ -34,7 +34,12 @@ from verl.utils import hf_tokenizer from 
verl.utils.checkpoint.megatron_checkpoint_manager import MegatronCheckpointManager from verl.utils.config import omega_conf_to_dataclass -from verl.utils.device import get_device_id, get_device_name, get_nccl_backend, get_torch_device +from verl.utils.device import ( + get_device_id, + get_device_name, + get_nccl_backend, + get_torch_device, +) from verl.utils.flops_counter import FlopsCounter from verl.utils.fs import copy_to_local from verl.utils.megatron_utils import ( @@ -43,7 +48,11 @@ offload_megatron_model_to_cpu, offload_megatron_optimizer, ) -from verl.utils.model import get_hf_model_path, load_mcore_dist_weights, load_megatron_gptmodel_weights +from verl.utils.model import ( + get_hf_model_path, + load_mcore_dist_weights, + load_megatron_gptmodel_weights, +) from verl.utils.profiler import ( DistProfiler, DistProfilerExtension, @@ -99,7 +108,9 @@ def __init__(self, config: DictConfig, role: str, **kwargs): rank = int(os.environ["LOCAL_RANK"]) torch.distributed.init_process_group( backend=get_nccl_backend(), - timeout=datetime.timedelta(seconds=self.config.get("nccl_timeout", 600)), + timeout=datetime.timedelta( + seconds=self.config.get("nccl_timeout", 600) + ), init_method=os.environ.get("DIST_INIT_METHOD", None), ) get_torch_device().set_device(rank) @@ -121,14 +132,26 @@ def __init__(self, config: DictConfig, role: str, **kwargs): set_random_seed(seed=self.config.actor.megatron.seed) self.role = role - assert self.role in ["actor", "rollout", "ref", "actor_rollout", "actor_rollout_ref"] + assert self.role in [ + "actor", + "rollout", + "ref", + "actor_rollout", + "actor_rollout_ref", + ] self._is_actor = self.role in ["actor", "actor_rollout", "actor_rollout_ref"] - self._is_rollout = self.role in ["rollout", "actor_rollout", "actor_rollout_ref"] + self._is_rollout = self.role in [ + "rollout", + "actor_rollout", + "actor_rollout_ref", + ] self._is_ref = self.role in ["ref", "actor_rollout_ref"] profiler_config = 
omega_conf_to_dataclass(config.get("profiler")) - DistProfilerExtension.__init__(self, DistProfiler(rank=self.rank, config=profiler_config)) + DistProfilerExtension.__init__( + self, DistProfiler(rank=self.rank, config=profiler_config) + ) # TODO(sgm): Currently, we only support reference model param offload # will support other offload later @@ -141,27 +164,59 @@ def __init__(self, config: DictConfig, role: str, **kwargs): self.config.actor.ppo_mini_batch_size *= self.config.rollout.n self.config.actor.ppo_mini_batch_size //= mpu.get_data_parallel_world_size() if self.config.actor.get("ppo_micro_batch_size", None): - self.config.actor.ppo_micro_batch_size //= mpu.get_data_parallel_world_size() - self.config.rollout.log_prob_micro_batch_size //= mpu.get_data_parallel_world_size() - self.config.actor.ppo_micro_batch_size_per_gpu = self.config.actor.ppo_micro_batch_size - self.config.rollout.log_prob_micro_batch_size_per_gpu = self.config.rollout.log_prob_micro_batch_size - - self._is_offload_param = self.config.actor.megatron.get("param_offload", False) - self._is_offload_grad = self.config.actor.megatron.get("grad_offload", False) - self._is_offload_optimizer = self.config.actor.megatron.get("optimizer_offload", False) + self.config.actor.ppo_micro_batch_size //= ( + mpu.get_data_parallel_world_size() + ) + self.config.rollout.log_prob_micro_batch_size //= ( + mpu.get_data_parallel_world_size() + ) + self.config.actor.ppo_micro_batch_size_per_gpu = ( + self.config.actor.ppo_micro_batch_size + ) + self.config.rollout.log_prob_micro_batch_size_per_gpu = ( + self.config.rollout.log_prob_micro_batch_size + ) + + self._is_offload_param = self.config.actor.megatron.get( + "param_offload", False + ) + self._is_offload_grad = self.config.actor.megatron.get( + "grad_offload", False + ) + self._is_offload_optimizer = self.config.actor.megatron.get( + "optimizer_offload", False + ) elif self._is_ref: if self.config.ref.get("log_prob_micro_batch_size", None): - 
self.config.ref.log_prob_micro_batch_size //= mpu.get_data_parallel_world_size() - self.config.ref.log_prob_micro_batch_size_per_gpu = self.config.ref.log_prob_micro_batch_size + self.config.ref.log_prob_micro_batch_size //= ( + mpu.get_data_parallel_world_size() + ) + self.config.ref.log_prob_micro_batch_size_per_gpu = ( + self.config.ref.log_prob_micro_batch_size + ) else: - assert self.config.ref.get("log_prob_micro_batch_size_per_gpu", None) is not None, ( + assert ( + self.config.ref.get("log_prob_micro_batch_size_per_gpu", None) + is not None + ), ( "Please note that in the ref policy configuration, `log_prob_micro_batch_size_per_gpu` and " "`log_prob_micro_batch_size` should not be None at the same time." ) - self._ref_is_offload_param = self.config.ref.megatron.get("param_offload", False) + self._ref_is_offload_param = self.config.ref.megatron.get( + "param_offload", False + ) - def _build_model_optimizer(self, model_path, optim_config, override_model_config, override_transformer_config): - from verl.utils.megatron.optimizer import get_megatron_optimizer, get_megatron_optimizer_param_scheduler + def _build_model_optimizer( + self, + model_path, + optim_config, + override_model_config, + override_transformer_config, + ): + from verl.utils.megatron.optimizer import ( + get_megatron_optimizer, + get_megatron_optimizer_param_scheduler, + ) from verl.utils.megatron_utils import get_model, init_megatron_optim_config from verl.utils.model import get_generation_config, print_model_size @@ -181,10 +236,13 @@ def make_model(wrap_with_ddp=False): from verl.models.mcore.mbridge import freeze_moe_router post_model_creation_callbacks = [] - if override_model_config.get("moe_config", {}).get("freeze_moe_router", False): + if override_model_config.get("moe_config", {}).get( + "freeze_moe_router", False + ): post_model_creation_callbacks.append(freeze_moe_router) return self.bridge.get_model( - post_model_creation_callbacks=post_model_creation_callbacks, 
wrap_with_ddp=wrap_with_ddp + post_model_creation_callbacks=post_model_creation_callbacks, + wrap_with_ddp=wrap_with_ddp, ) else: @@ -198,7 +256,9 @@ def megatron_actor_model_provider(pre_process, post_process): post_process, share_embeddings_and_output_weights=self.share_embeddings_and_output_weights, value=False, - freeze_moe_router=override_model_config.get("moe_config", {}).get("freeze_moe_router", False), + freeze_moe_router=override_model_config.get( + "moe_config", {} + ).get("freeze_moe_router", False), ) parallel_model.to(get_device_name()) return parallel_model @@ -215,7 +275,9 @@ def megatron_actor_model_provider(pre_process, post_process): if self.config.actor.load_weight: if self.config.actor.megatron.use_dist_checkpointing: load_mcore_dist_weights( - actor_module, self.config.actor.megatron.dist_checkpointing_path, is_value_model=False + actor_module, + self.config.actor.megatron.dist_checkpointing_path, + is_value_model=False, ) else: if self.bridge is not None: @@ -223,7 +285,11 @@ def megatron_actor_model_provider(pre_process, post_process): self.bridge.load_weights(actor_module, local_model_path) else: load_megatron_gptmodel_weights( - self.config, self.hf_config, actor_module, params_dtype=self.dtype, is_value_model=False + self.config, + self.hf_config, + actor_module, + params_dtype=self.dtype, + is_value_model=False, ) if self.rank == 0: @@ -237,7 +303,9 @@ def megatron_actor_model_provider(pre_process, post_process): print("load ref weight start") if self.config.ref.megatron.use_dist_checkpointing: load_mcore_dist_weights( - ref_module, self.config.ref.megatron.dist_checkpointing_path, is_value_model=False + ref_module, + self.config.ref.megatron.dist_checkpointing_path, + is_value_model=False, ) else: if self.bridge is not None: @@ -245,7 +313,11 @@ def megatron_actor_model_provider(pre_process, post_process): self.bridge.load_weights(ref_module, local_model_path) else: load_megatron_gptmodel_weights( - self.config, self.hf_config, 
ref_module, params_dtype=self.dtype, is_value_model=False + self.config, + self.hf_config, + ref_module, + params_dtype=self.dtype, + is_value_model=False, ) log_gpu_memory_usage("After ref module init", logger=logger) return ref_module, self.hf_config @@ -253,7 +325,9 @@ def megatron_actor_model_provider(pre_process, post_process): # TODO: add more optimizer args into config if self._is_actor: optim_config_megatron = init_megatron_optim_config(optim_config) - actor_optimizer = get_megatron_optimizer(model=actor_module, config=optim_config_megatron) + actor_optimizer = get_megatron_optimizer( + model=actor_module, config=optim_config_megatron + ) actor_optimizer_scheduler = get_megatron_optimizer_param_scheduler( optimizer=actor_optimizer, config=optim_config ) @@ -264,7 +338,13 @@ def megatron_actor_model_provider(pre_process, post_process): log_gpu_memory_usage("After actor optimizer init", logger=logger) - return actor_module, actor_optimizer, actor_optimizer_scheduler, self.hf_config, optim_config + return ( + actor_module, + actor_optimizer, + actor_optimizer_scheduler, + self.hf_config, + optim_config, + ) def _build_rollout(self, trust_remote_code=False): from torch.distributed.device_mesh import init_device_mesh @@ -277,25 +357,33 @@ def _build_rollout(self, trust_remote_code=False): from torch.distributed.device_mesh import init_device_mesh from verl.workers.rollout.vllm_rollout import vLLMRollout - from verl.workers.sharding_manager.megatron_vllm import MegatronVLLMShardingManager + from verl.workers.sharding_manager.megatron_vllm import ( + MegatronVLLMShardingManager, + ) # NOTE(sgm): If the QKV and gate_up projection layer are concate together in actor, # we will reorganize their weight format when resharding from actor to rollout. 
infer_tp = self.config.rollout.tensor_model_parallel_size dp = self.world_size // infer_tp - assert self.world_size % infer_tp == 0, ( - f"rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}" - ) + assert ( + self.world_size % infer_tp == 0 + ), f"rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}" rollout_device_mesh = init_device_mesh( - get_device_name(), mesh_shape=(dp, infer_tp), mesh_dim_names=["dp", "infer_tp"] + get_device_name(), + mesh_shape=(dp, infer_tp), + mesh_dim_names=["dp", "infer_tp"], ) log_gpu_memory_usage("Before building vllm rollout", logger=None) - local_path = copy_to_local(self.config.model.path, use_shm=self.config.model.get("use_shm", False)) + local_path = copy_to_local( + self.config.model.path, use_shm=self.config.model.get("use_shm", False) + ) from verl.workers.rollout.vllm_rollout import vLLMAsyncRollout - vllm_rollout_cls = vLLMRollout if self.config.rollout.mode == "sync" else vLLMAsyncRollout + vllm_rollout_cls = ( + vLLMRollout if self.config.rollout.mode == "sync" else vLLMAsyncRollout + ) rollout = vllm_rollout_cls( model_path=local_path, config=self.config.rollout, @@ -309,7 +397,9 @@ def _build_rollout(self, trust_remote_code=False): # perform weight resharding between actor and rollout from verl.models.mcore import get_mcore_weight_converter - weight_converter = get_mcore_weight_converter(self.actor_model_config, self.dtype) + weight_converter = get_mcore_weight_converter( + self.actor_model_config, self.dtype + ) sharding_manager = MegatronVLLMShardingManager( inference_engine=rollout.inference_engine, model_config=self.actor_model_config, @@ -334,32 +424,42 @@ def _build_rollout(self, trust_remote_code=False): # For this reason, sharding_manager.__init__ should not import FSDPSGLangShardingManager and we import it # here use the abs path. 
# check: https://github.com/sgl-project/sglang/blob/00f42707eaddfc2c0528e5b1e0094025c640b7a0/python/sglang/srt/layers/quantization/fp8_utils.py#L76 - from verl.workers.sharding_manager.megatron_sglang import MegatronSGLangShardingManager + from verl.workers.sharding_manager.megatron_sglang import ( + MegatronSGLangShardingManager, + ) infer_tp = self.config.rollout.tensor_model_parallel_size dp = self.world_size // infer_tp - assert self.world_size % infer_tp == 0, ( - f"rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}" - ) + assert ( + self.world_size % infer_tp == 0 + ), f"rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}" rollout_device_mesh = init_device_mesh( "cpu", mesh_shape=(dp, infer_tp, 1), mesh_dim_names=("dp", "tp", "pp") ) local_path = copy_to_local(self.config.model.path) - log_gpu_memory_usage(f"Before building {self.config.rollout.name} rollout", logger=None) + log_gpu_memory_usage( + f"Before building {self.config.rollout.name} rollout", logger=None + ) rollout = SGLangRollout( actor_module=local_path, config=self.config.rollout, - processing_class=self.processor if self.processor is not None else self.tokenizer, + processing_class=( + self.processor if self.processor is not None else self.tokenizer + ), model_hf_config=self.actor_model_config, trust_remote_code=trust_remote_code, device_mesh=rollout_device_mesh, ) - log_gpu_memory_usage(f"After building {self.config.rollout.name} rollout", logger=None) + log_gpu_memory_usage( + f"After building {self.config.rollout.name} rollout", logger=None + ) from verl.models.mcore import get_mcore_weight_converter - weight_converter = get_mcore_weight_converter(self.actor_model_config, self.dtype) + weight_converter = get_mcore_weight_converter( + self.actor_model_config, self.dtype + ) sharding_manager = MegatronSGLangShardingManager( actor_module=self.actor.actor_module, inference_engine=rollout._engine, @@ -375,7 +475,9 @@ def _build_rollout(self, 
trust_remote_code=False): log_gpu_memory_usage("After building sharding manager", logger=logger) else: raise NotImplementedError("Only vllmRollout is supported with Megatron now") - print(f"rollout and sharding manager init done sharding_manager: {sharding_manager}") + print( + f"rollout and sharding manager init done sharding_manager: {sharding_manager}" + ) return rollout, sharding_manager @register(dispatch_mode=Dispatch.ONE_TO_ALL) @@ -388,14 +490,22 @@ def init_model(self): from verl.utils.torch_dtypes import PrecisionType - override_model_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create())) + override_model_config = OmegaConf.to_container( + self.config.model.get("override_config", OmegaConf.create()) + ) if self._is_actor: override_transformer_config = OmegaConf.to_container( - self.config.actor.megatron.get("override_transformer_config", OmegaConf.create()), resolve=True + self.config.actor.megatron.get( + "override_transformer_config", OmegaConf.create() + ), + resolve=True, ) elif self._is_ref: override_transformer_config = OmegaConf.to_container( - self.config.ref.megatron.get("override_transformer_config", OmegaConf.create()), resolve=True + self.config.ref.megatron.get( + "override_transformer_config", OmegaConf.create() + ), + resolve=True, ) else: override_transformer_config = None @@ -419,10 +529,14 @@ def init_model(self): ) if self._is_offload_param: offload_megatron_model_to_cpu(self.actor_module) - log_gpu_memory_usage("After offload actor params and grad during init", logger=logger) + log_gpu_memory_usage( + "After offload actor params and grad during init", logger=logger + ) if self._is_offload_optimizer: offload_megatron_optimizer(self.actor_optimizer) - log_gpu_memory_usage("After offload actor optimizer during init", logger=logger) + log_gpu_memory_usage( + "After offload actor optimizer during init", logger=logger + ) if self._is_actor: OmegaConf.set_struct(self.config.actor, True) @@ -465,7 +579,9 
@@ def init_model(self): ) if self._ref_is_offload_param: offload_megatron_model_to_cpu(self.ref_module) - log_gpu_memory_usage("After offload ref params during init", logger=logger) + log_gpu_memory_usage( + "After offload ref params during init", logger=logger + ) if self._is_actor: self.flops_counter = FlopsCounter(self.actor_model_config) @@ -480,7 +596,9 @@ def init_model(self): hf_config=self.hf_config, param_dtype=self.param_dtype, share_embeddings_and_output_weights=self.share_embeddings_and_output_weights, - processing_class=self.processor if self.processor is not None else self.tokenizer, + processing_class=( + self.processor if self.processor is not None else self.tokenizer + ), optimizer=self.actor_optimizer, optimizer_scheduler=self.actor_optimizer_scheduler, use_distributed_optimizer=self.config.actor.megatron.use_distributed_optimizer, @@ -498,10 +616,14 @@ def update_actor(self, data: DataProto): assert self._is_actor if self._is_offload_param: load_megatron_model_to_gpu(self.actor_module) - log_gpu_memory_usage("After load actor params and grad during update_actor", logger=logger) + log_gpu_memory_usage( + "After load actor params and grad during update_actor", logger=logger + ) if self._is_offload_optimizer: load_megatron_optimizer(self.actor_optimizer) - log_gpu_memory_usage("After load actor optimizer during update_actor", logger=logger) + log_gpu_memory_usage( + "After load actor optimizer during update_actor", logger=logger + ) data.batch = data.batch.to(get_device_name()) micro_batch_size = self.config.actor.ppo_micro_batch_size_per_gpu @@ -511,10 +633,21 @@ def update_actor(self, data: DataProto): metrics = self.actor.update_policy(dataloader=dataloader) delta_time = timer.last global_num_tokens = data.meta_info["global_token_num"] - estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time) - metrics["perf/mfu/actor"] = estimated_flops * self.config.actor.ppo_epochs / promised_flops / self.world_size 
- metrics["perf/max_memory_allocated_gb"] = get_torch_device().max_memory_allocated() / (1024**3) - metrics["perf/max_memory_reserved_gb"] = get_torch_device().max_memory_reserved() / (1024**3) + estimated_flops, promised_flops = self.flops_counter.estimate_flops( + global_num_tokens, delta_time + ) + metrics["perf/mfu/actor"] = ( + estimated_flops + * self.config.actor.ppo_epochs + / promised_flops + / self.world_size + ) + metrics["perf/max_memory_allocated_gb"] = ( + get_torch_device().max_memory_allocated() / (1024**3) + ) + metrics["perf/max_memory_reserved_gb"] = ( + get_torch_device().max_memory_reserved() / (1024**3) + ) metrics["perf/cpu_memory_used_gb"] = psutil.virtual_memory().used / (1024**3) from verl.utils.megatron.optimizer import get_megatron_last_lr @@ -527,10 +660,14 @@ def update_actor(self, data: DataProto): if self._is_offload_param: offload_megatron_model_to_cpu(self.actor_module) - log_gpu_memory_usage("After offload actor params and grad during update_actor", logger=logger) + log_gpu_memory_usage( + "After offload actor params and grad during update_actor", logger=logger + ) if self._is_offload_optimizer: offload_megatron_optimizer(self.actor_optimizer) - log_gpu_memory_usage("After offload actor optimizer during update_actor", logger=logger) + log_gpu_memory_usage( + "After offload actor optimizer during update_actor", logger=logger + ) get_torch_device().empty_cache() return output @@ -542,12 +679,16 @@ def generate_sequences(self, prompts: DataProto): assert self._is_rollout prompts.batch = prompts.batch.to(get_device_name()) meta_info = { - "eos_token_id": self.generation_config.eos_token_id - if self.generation_config is not None - else self.tokenizer.eos_token_id, - "pad_token_id": self.generation_config.pad_token_id - if self.generation_config is not None - else self.tokenizer.pad_token_id, + "eos_token_id": ( + self.generation_config.eos_token_id + if self.generation_config is not None + else self.tokenizer.eos_token_id + ), + 
"pad_token_id": ( + self.generation_config.pad_token_id + if self.generation_config is not None + else self.tokenizer.pad_token_id + ), } prompts.meta_info.update(meta_info) if self._is_offload_optimizer: @@ -579,7 +720,10 @@ def compute_ref_log_prob(self, data: DataProto): assert self._is_ref if self._ref_is_offload_param: load_megatron_model_to_gpu(self.ref_module, load_grad=False) - log_gpu_memory_usage("After load ref params and grad during compute_ref_log_prob", logger=logger) + log_gpu_memory_usage( + "After load ref params and grad during compute_ref_log_prob", + logger=logger, + ) micro_batch_size = self.config.ref.log_prob_micro_batch_size_per_gpu data.meta_info["micro_batch_size"] = micro_batch_size data.meta_info["max_token_len"] = self.config.ref.log_prob_max_token_len_per_gpu @@ -591,7 +735,10 @@ def compute_ref_log_prob(self, data: DataProto): output = output.to("cpu") if self._ref_is_offload_param: offload_megatron_model_to_cpu(self.ref_module) - log_gpu_memory_usage("After offload ref params and grad during compute_ref_log_prob", logger=logger) + log_gpu_memory_usage( + "After offload ref params and grad during compute_ref_log_prob", + logger=logger, + ) get_torch_device().empty_cache() return output @@ -602,14 +749,23 @@ def compute_log_prob(self, data: DataProto): assert self._is_actor if self._is_offload_param: load_megatron_model_to_gpu(self.actor_module, load_grad=False) - log_gpu_memory_usage("After load actor params and grad during compute_log_prob", logger=logger) + log_gpu_memory_usage( + "After load actor params and grad during compute_log_prob", + logger=logger, + ) # we should always recompute old_log_probs when it is HybridEngine - data.meta_info["micro_batch_size"] = self.config.rollout.log_prob_micro_batch_size_per_gpu - data.meta_info["max_token_len"] = self.config.rollout.log_prob_max_token_len_per_gpu + data.meta_info["micro_batch_size"] = ( + self.config.rollout.log_prob_micro_batch_size_per_gpu + ) + 
data.meta_info["max_token_len"] = ( + self.config.rollout.log_prob_max_token_len_per_gpu + ) data.meta_info["use_dynamic_bsz"] = self.config.rollout.log_prob_use_dynamic_bsz data.meta_info["temperature"] = self.config.rollout.temperature data = data.to(get_device_id()) - output, entropys = self.actor.compute_log_prob(data=data, calculate_entropy=True) + output, entropys = self.actor.compute_log_prob( + data=data, calculate_entropy=True + ) output = DataProto.from_dict( tensors={"old_log_probs": output, "entropys": entropys}, meta_info={"temperature": self.config.rollout.temperature}, @@ -618,16 +774,23 @@ def compute_log_prob(self, data: DataProto): # clear kv cache if self._is_offload_param: offload_megatron_model_to_cpu(self.actor_module) - log_gpu_memory_usage("After offload actor params and grad during compute_log_prob", logger=logger) + log_gpu_memory_usage( + "After offload actor params and grad during compute_log_prob", + logger=logger, + ) get_torch_device().empty_cache() return output @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def load_checkpoint(self, checkpoint_path, hdfs_path=None, del_local_after_load=True): + def load_checkpoint( + self, checkpoint_path, hdfs_path=None, del_local_after_load=True + ): if self._is_offload_param: load_megatron_model_to_gpu(self.actor_module) self.checkpoint_mananager.load_checkpoint( - local_path=checkpoint_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load + local_path=checkpoint_path, + hdfs_path=hdfs_path, + del_local_after_load=del_local_after_load, ) if self._is_offload_param: offload_megatron_model_to_cpu(self.actor_module) @@ -639,11 +802,16 @@ def load_pretrained_model(self, checkpoint_path, del_local_after_load=True): pass @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def save_checkpoint(self, checkpoint_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None): + def save_checkpoint( + self, checkpoint_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None + ): if 
self._is_offload_param: load_megatron_model_to_gpu(self.actor_module) self.checkpoint_mananager.save_checkpoint( - local_path=checkpoint_path, hdfs_path=hdfs_path, global_step=global_step, max_ckpt_to_keep=max_ckpt_to_keep + local_path=checkpoint_path, + hdfs_path=hdfs_path, + global_step=global_step, + max_ckpt_to_keep=max_ckpt_to_keep, ) torch.distributed.barrier() if self._is_offload_param: @@ -690,7 +858,9 @@ async def chat_completion(self, json_request): return ret @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD, blocking=False) - async def generate(self, prompt_ids: list[int], sampling_params: dict[str, Any], request_id: str) -> list[int]: + async def generate( + self, prompt_ids: list[int], sampling_params: dict[str, Any], request_id: str + ) -> list[int]: ret = await self.rollout.generate(prompt_ids, sampling_params, request_id) return ret @@ -713,7 +883,10 @@ class CriticWorker(MegatronWorker, DistProfilerExtension): def __init__(self, config): MegatronWorker.__init__(self) DistProfilerExtension.__init__( - self, DistProfiler(rank=self.rank, config=omega_conf_to_dataclass(config.get("profiler"))) + self, + DistProfiler( + rank=self.rank, config=omega_conf_to_dataclass(config.get("profiler")) + ), ) self.config = config @@ -727,7 +900,9 @@ def __init__(self, config): rank = int(os.environ["LOCAL_RANK"]) torch.distributed.init_process_group( backend=get_nccl_backend(), - timeout=datetime.timedelta(seconds=self.config.get("nccl_timeout", 600)), + timeout=datetime.timedelta( + seconds=self.config.get("nccl_timeout", 600) + ), init_method=os.environ.get("DIST_INIT_METHOD", None), ) get_torch_device().set_device(rank) @@ -762,11 +937,18 @@ def __init__(self, config): # TODO(sgm): support critic model offload def _build_critic_model_optimizer( - self, model_path, optim_config, override_model_config, override_transformer_config + self, + model_path, + optim_config, + override_model_config, + override_transformer_config, ): from 
megatron.core.models.gpt.gpt_model import ModelType - from verl.utils.megatron.optimizer import get_megatron_optimizer, get_megatron_optimizer_param_scheduler + from verl.utils.megatron.optimizer import ( + get_megatron_optimizer, + get_megatron_optimizer_param_scheduler, + ) from verl.utils.megatron_utils import get_model, init_megatron_optim_config from verl.utils.model import print_model_size @@ -784,10 +966,13 @@ def _build_critic_model_optimizer( from verl.models.mcore.mbridge import freeze_moe_router, make_value_model post_model_creation_callbacks = [make_value_model] - if override_model_config.get("moe_config", {}).get("freeze_moe_router", False): + if override_model_config.get("moe_config", {}).get( + "freeze_moe_router", False + ): post_model_creation_callbacks.append(freeze_moe_router) critic_module = self.bridge.get_model( - post_model_creation_callbacks=post_model_creation_callbacks, wrap_with_ddp=True + post_model_creation_callbacks=post_model_creation_callbacks, + wrap_with_ddp=True, ) else: @@ -801,7 +986,9 @@ def megatron_critic_model_provider(pre_process, post_process): post_process, share_embeddings_and_output_weights=False, value=True, - freeze_moe_router=override_model_config.get("moe_config", {}).get("freeze_moe_router", False), + freeze_moe_router=override_model_config.get("moe_config", {}).get( + "freeze_moe_router", False + ), ) parallel_model.to(get_device_name()) return parallel_model @@ -821,7 +1008,9 @@ def megatron_critic_model_provider(pre_process, post_process): t0 = time.time() if self.config.megatron.use_dist_checkpointing: load_mcore_dist_weights( - critic_module, self.config.megatron.dist_checkpointing_path, is_value_model=True + critic_module, + self.config.megatron.dist_checkpointing_path, + is_value_model=True, ) else: if self.bridge is not None: @@ -829,7 +1018,11 @@ def megatron_critic_model_provider(pre_process, post_process): self.bridge.load_weights(critic_module, local_model_path) else: load_megatron_gptmodel_weights( - 
self.config, self.hf_config, critic_module, params_dtype=self.dtype, is_value_model=True + self.config, + self.hf_config, + critic_module, + params_dtype=self.dtype, + is_value_model=True, ) t1 = time.time() if torch.distributed.get_rank() == 0: @@ -839,12 +1032,20 @@ def megatron_critic_model_provider(pre_process, post_process): # TODO: add more optimizer args into config optim_config_megatron = init_megatron_optim_config(optim_config) - critic_optimizer = get_megatron_optimizer(model=critic_module, config=optim_config_megatron) + critic_optimizer = get_megatron_optimizer( + model=critic_module, config=optim_config_megatron + ) critic_optimizer_scheduler = get_megatron_optimizer_param_scheduler( optimizer=critic_optimizer, config=optim_config ) get_torch_device().empty_cache() - return critic_module, critic_optimizer, critic_optimizer_scheduler, self.hf_config, optim_config + return ( + critic_module, + critic_optimizer, + critic_optimizer_scheduler, + self.hf_config, + optim_config, + ) @register(dispatch_mode=Dispatch.ONE_TO_ALL) def init_model(self): @@ -857,9 +1058,12 @@ def init_model(self): import importlib importlib.import_module(self.config.model.external_lib) - override_model_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create())) + override_model_config = OmegaConf.to_container( + self.config.model.get("override_config", OmegaConf.create()) + ) override_transformer_config = OmegaConf.to_container( - self.config.megatron.get("override_transformer_config", OmegaConf.create()), resolve=True + self.config.megatron.get("override_transformer_config", OmegaConf.create()), + resolve=True, ) self.param_dtype = torch.bfloat16 self.dtype = PrecisionType.to_dtype(self.param_dtype) @@ -901,7 +1105,9 @@ def init_model(self): hf_config=self.hf_config, param_dtype=self.param_dtype, share_embeddings_and_output_weights=False, - processing_class=self.processor if self.processor is not None else self.tokenizer, + processing_class=( + 
self.processor if self.processor is not None else self.tokenizer + ), optimizer=self.critic_optimizer, optimizer_scheduler=self.critic_optimizer_scheduler, use_distributed_optimizer=self.config.megatron.use_distributed_optimizer, @@ -942,8 +1148,12 @@ def update_critic(self, data: DataProto): metrics = self.critic.update_critic(dataloader=dataloader) delta_time = timer.last global_num_tokens = data.meta_info["global_token_num"] - estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time) - metrics["perf/mfu/critic"] = estimated_flops * self.config.ppo_epochs / promised_flops / self.world_size + estimated_flops, promised_flops = self.flops_counter.estimate_flops( + global_num_tokens, delta_time + ) + metrics["perf/mfu/critic"] = ( + estimated_flops * self.config.ppo_epochs / promised_flops / self.world_size + ) from verl.utils.megatron.optimizer import get_megatron_last_lr metrics["critic/lr"] = get_megatron_last_lr(self.critic_optimizer) @@ -959,11 +1169,15 @@ def update_critic(self, data: DataProto): return output @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def load_checkpoint(self, checkpoint_path, hdfs_path=None, del_local_after_load=True): + def load_checkpoint( + self, checkpoint_path, hdfs_path=None, del_local_after_load=True + ): if self._is_offload_param: load_megatron_model_to_gpu(self.critic_module) self.checkpoint_mananager.load_checkpoint( - local_path=checkpoint_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load + local_path=checkpoint_path, + hdfs_path=hdfs_path, + del_local_after_load=del_local_after_load, ) if self._is_offload_param: offload_megatron_model_to_cpu(self.critic_module) @@ -971,11 +1185,16 @@ def load_checkpoint(self, checkpoint_path, hdfs_path=None, del_local_after_load= offload_megatron_optimizer(self.critic_optimizer) @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def save_checkpoint(self, checkpoint_path, hdfs_path=None, global_steps=0, max_ckpt_to_keep=None): + def 
save_checkpoint( + self, checkpoint_path, hdfs_path=None, global_steps=0, max_ckpt_to_keep=None + ): if self._is_offload_param: load_megatron_model_to_gpu(self.critic_module) self.checkpoint_mananager.save_checkpoint( - local_path=checkpoint_path, hdfs_path=hdfs_path, global_step=global_steps, max_ckpt_to_keep=max_ckpt_to_keep + local_path=checkpoint_path, + hdfs_path=hdfs_path, + global_step=global_steps, + max_ckpt_to_keep=max_ckpt_to_keep, ) if self._is_offload_param: offload_megatron_model_to_cpu(self.critic_module) @@ -989,7 +1208,10 @@ class RewardModelWorker(MegatronWorker, DistProfilerExtension): def __init__(self, config): MegatronWorker.__init__(self) DistProfilerExtension.__init__( - self, DistProfiler(rank=self.rank, config=omega_conf_to_dataclass(config.get("profiler"))) + self, + DistProfiler( + rank=self.rank, config=omega_conf_to_dataclass(config.get("profiler")) + ), ) self.config = config @@ -1003,7 +1225,9 @@ def __init__(self, config): rank = int(os.environ["LOCAL_RANK"]) torch.distributed.init_process_group( backend=get_nccl_backend(), - timeout=datetime.timedelta(seconds=self.config.get("nccl_timeout", 600)), + timeout=datetime.timedelta( + seconds=self.config.get("nccl_timeout", 600) + ), init_method=os.environ.get("DIST_INIT_METHOD", None), ) get_torch_device().set_device(rank) @@ -1029,7 +1253,9 @@ def __init__(self, config): self.config.micro_batch_size //= mpu.get_data_parallel_world_size() self.config.micro_batch_size_per_gpu = self.config.micro_batch_size - def _build_rm_model(self, model_path, tokenizer, override_model_config, override_transformer_config): + def _build_rm_model( + self, model_path, tokenizer, override_model_config, override_transformer_config + ): from megatron.core.models.gpt.gpt_model import ModelType from verl.utils.megatron_utils import get_model @@ -1047,10 +1273,13 @@ def _build_rm_model(self, model_path, tokenizer, override_model_config, override from verl.models.mcore.mbridge import freeze_moe_router, 
make_value_model post_model_creation_callbacks = [make_value_model] - if override_model_config.get("moe_config", {}).get("freeze_moe_router", False): + if override_model_config.get("moe_config", {}).get( + "freeze_moe_router", False + ): post_model_creation_callbacks.append(freeze_moe_router) reward_model = self.bridge.get_model( - post_model_creation_callbacks=post_model_creation_callbacks, wrap_with_ddp=False + post_model_creation_callbacks=post_model_creation_callbacks, + wrap_with_ddp=False, ) else: @@ -1081,14 +1310,22 @@ def megatron_rm_model_provider(pre_process, post_process): if self.config.load_weight: if self.config.megatron.use_dist_checkpointing: - load_mcore_dist_weights(reward_model, self.config.megatron.dist_checkpointing_path, is_value_model=True) + load_mcore_dist_weights( + reward_model, + self.config.megatron.dist_checkpointing_path, + is_value_model=True, + ) else: if self.bridge is not None: local_model_path = get_hf_model_path(self.config) self.bridge.load_weights(reward_model, local_model_path) else: load_megatron_gptmodel_weights( - self.config, self.hf_config, reward_model, params_dtype=self.dtype, is_value_model=True + self.config, + self.hf_config, + reward_model, + params_dtype=self.dtype, + is_value_model=True, ) # TODO: add more optimizer args into config @@ -1106,20 +1343,26 @@ def init_model(self): import importlib importlib.import_module(self.config.model.external_lib) - override_model_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create())) + override_model_config = OmegaConf.to_container( + self.config.model.get("override_config", OmegaConf.create()) + ) override_transformer_config = OmegaConf.to_container( - self.config.megatron.get("override_transformer_config", OmegaConf.create()), resolve=True + self.config.megatron.get("override_transformer_config", OmegaConf.create()), + resolve=True, ) use_shm = self.config.model.get("use_shm", False) - sft_tokenizer_local_path = 
copy_to_local(self.config.model.input_tokenizer, use_shm=use_shm) + sft_tokenizer_local_path = copy_to_local( + self.config.model.input_tokenizer, use_shm=use_shm + ) sft_tokenizer = hf_tokenizer(sft_tokenizer_local_path) rm_tokenizer_path = self.config.model.get("rm_tokenizer", None) rm_tokenizer = None if rm_tokenizer_path is not None: rm_tokenizer_local_path = copy_to_local(rm_tokenizer_path, use_shm=use_shm) rm_tokenizer = hf_tokenizer( - rm_tokenizer_local_path, trust_remote_code=self.config.model.get("trust_remote_code", False) + rm_tokenizer_local_path, + trust_remote_code=self.config.model.get("trust_remote_code", False), ) self.param_dtype = torch.bfloat16 diff --git a/Agent0/executor_train/verl/verl/workers/reward_manager/batch.py b/Agent0/executor_train/verl/verl/workers/reward_manager/batch.py index 8d1b112..eb9d626 100644 --- a/Agent0/executor_train/verl/verl/workers/reward_manager/batch.py +++ b/Agent0/executor_train/verl/verl/workers/reward_manager/batch.py @@ -33,7 +33,14 @@ class BatchRewardManager: reward_kwargs (dict): The keyword arguments to pass to the reward function. 
""" - def __init__(self, tokenizer, num_examine, compute_score, reward_fn_key="data_source", **reward_kwargs): + def __init__( + self, + tokenizer, + num_examine, + compute_score, + reward_fn_key="data_source", + **reward_kwargs + ): self.tokenizer = tokenizer self.num_examine = num_examine self.compute_score = compute_score @@ -52,10 +59,15 @@ def verify(self, data): for i in range(len(data)): valid_len = valid_response_lengths[i] valid_response_ids = response_ids[i][:valid_len] - response_str = self.tokenizer.decode(valid_response_ids, skip_special_tokens=True) + response_str = self.tokenizer.decode( + valid_response_ids, skip_special_tokens=True + ) responses_str.append(response_str) - ground_truths = [item.non_tensor_batch["reward_model"].get("ground_truth", None) for item in data] + ground_truths = [ + item.non_tensor_batch["reward_model"].get("ground_truth", None) + for item in data + ] data_sources = data.non_tensor_batch[self.reward_fn_key] extras = data.non_tensor_batch.get("extra_info", [None] * len(data)) @@ -105,18 +117,29 @@ def __call__(self, data: DataProto, return_dict=False): data_source = data_sources[i] if already_printed.get(data_source, 0) < self.num_examine: - response_str = self.tokenizer.decode(data.batch["responses"][i][:length], skip_special_tokens=True) - prompt_str = self.tokenizer.decode(data.batch["prompts"][i], skip_special_tokens=True) - ground_truth = data[i].non_tensor_batch["reward_model"].get("ground_truth", None) + response_str = self.tokenizer.decode( + data.batch["responses"][i][:length], skip_special_tokens=True + ) + prompt_str = self.tokenizer.decode( + data.batch["prompts"][i], skip_special_tokens=True + ) + ground_truth = ( + data[i].non_tensor_batch["reward_model"].get("ground_truth", None) + ) print("[prompt]", prompt_str) print("[response]", response_str) print("[ground_truth]", ground_truth) print("[score]", scores[i]) already_printed[data_source] = already_printed.get(data_source, 0) + 1 - data.batch["acc"] = 
torch.tensor(rewards, dtype=torch.float32, device=prompt_ids.device) + data.batch["acc"] = torch.tensor( + rewards, dtype=torch.float32, device=prompt_ids.device + ) if return_dict: - return {"reward_tensor": reward_tensor, "reward_extra_info": reward_extra_info} + return { + "reward_tensor": reward_tensor, + "reward_extra_info": reward_extra_info, + } else: return reward_tensor diff --git a/Agent0/executor_train/verl/verl/workers/reward_manager/dapo.py b/Agent0/executor_train/verl/verl/workers/reward_manager/dapo.py index 3ba9afe..15e470d 100644 --- a/Agent0/executor_train/verl/verl/workers/reward_manager/dapo.py +++ b/Agent0/executor_train/verl/verl/workers/reward_manager/dapo.py @@ -42,12 +42,12 @@ def __init__( self.max_resp_len = max_resp_len if self.overlong_buffer_cfg is not None: - assert self.max_resp_len is not None, ( - f"max_resp_len must be provided if {overlong_buffer_cfg=}, but got None" - ) - assert self.max_resp_len >= self.overlong_buffer_cfg.len, ( - "max_resp_len must be larger than overlong_buffer.len" - ) + assert ( + self.max_resp_len is not None + ), f"max_resp_len must be provided if {overlong_buffer_cfg=}, but got None" + assert ( + self.max_resp_len >= self.overlong_buffer_cfg.len + ), "max_resp_len must be larger than overlong_buffer.len" def __call__(self, data: DataProto, return_dict: bool = False): """We will expand this function gradually based on the available datasets""" @@ -71,16 +71,24 @@ def __call__(self, data: DataProto, return_dict: bool = False): prompt_length = prompt_ids.shape[-1] - valid_prompt_length = data_item.batch["attention_mask"][:prompt_length].sum() + valid_prompt_length = data_item.batch["attention_mask"][ + :prompt_length + ].sum() valid_prompt_ids = prompt_ids[-valid_prompt_length:] response_ids = data_item.batch["responses"] - valid_response_length = data_item.batch["attention_mask"][prompt_length:].sum() + valid_response_length = data_item.batch["attention_mask"][ + prompt_length: + ].sum() 
valid_response_ids = response_ids[:valid_response_length] # decode - prompt_str = self.tokenizer.decode(valid_prompt_ids, skip_special_tokens=True) - response_str = self.tokenizer.decode(valid_response_ids, skip_special_tokens=True) + prompt_str = self.tokenizer.decode( + valid_prompt_ids, skip_special_tokens=True + ) + response_str = self.tokenizer.decode( + valid_response_ids, skip_special_tokens=True + ) eos_token = self.tokenizer.eos_token if response_str.endswith(eos_token): response_str = response_str[: -len(eos_token)] @@ -114,7 +122,9 @@ def __call__(self, data: DataProto, return_dict: bool = False): expected_len = self.max_resp_len - overlong_buffer_len exceed_len = valid_response_length - expected_len overlong_penalty_factor = self.overlong_buffer_cfg.penalty_factor - overlong_reward = min(-exceed_len / overlong_buffer_len * overlong_penalty_factor, 0) + overlong_reward = min( + -exceed_len / overlong_buffer_len * overlong_penalty_factor, 0 + ) reward += overlong_reward if self.overlong_buffer_cfg.log: reward_extra_info["overlong_reward"].append(overlong_reward) diff --git a/Agent0/executor_train/verl/verl/workers/reward_manager/naive.py b/Agent0/executor_train/verl/verl/workers/reward_manager/naive.py index f6f979e..7e1926d 100644 --- a/Agent0/executor_train/verl/verl/workers/reward_manager/naive.py +++ b/Agent0/executor_train/verl/verl/workers/reward_manager/naive.py @@ -25,7 +25,9 @@ class NaiveRewardManager: """The reward manager.""" - def __init__(self, tokenizer, num_examine, compute_score=None, reward_fn_key="data_source") -> None: + def __init__( + self, tokenizer, num_examine, compute_score=None, reward_fn_key="data_source" + ) -> None: """ Initialize the NaiveRewardManager instance. 
@@ -39,7 +41,9 @@ def __init__(self, tokenizer, num_examine, compute_score=None, reward_fn_key="da self.tokenizer = tokenizer # Store the tokenizer for decoding token IDs self.num_examine = num_examine # the number of batches of decoded responses to print to the console self.compute_score = compute_score or default_compute_score - self.reward_fn_key = reward_fn_key # Store the key for accessing the data source + self.reward_fn_key = ( + reward_fn_key # Store the key for accessing the data source + ) def __call__(self, data: DataProto, return_dict=False): """We will expand this function gradually based on the available datasets""" @@ -63,16 +67,24 @@ def __call__(self, data: DataProto, return_dict=False): prompt_length = prompt_ids.shape[-1] - valid_prompt_length = data_item.batch["attention_mask"][:prompt_length].sum() + valid_prompt_length = data_item.batch["attention_mask"][ + :prompt_length + ].sum() valid_prompt_ids = prompt_ids[-valid_prompt_length:] response_ids = data_item.batch["responses"] - valid_response_length = data_item.batch["attention_mask"][prompt_length:].sum() + valid_response_length = data_item.batch["attention_mask"][ + prompt_length: + ].sum() valid_response_ids = response_ids[:valid_response_length] # decode - prompt_str = self.tokenizer.decode(valid_prompt_ids, skip_special_tokens=True) - response_str = self.tokenizer.decode(valid_response_ids, skip_special_tokens=True) + prompt_str = self.tokenizer.decode( + valid_prompt_ids, skip_special_tokens=True + ) + response_str = self.tokenizer.decode( + valid_response_ids, skip_special_tokens=True + ) ground_truth = data_item.non_tensor_batch["reward_model"]["ground_truth"] data_source = data_item.non_tensor_batch[self.reward_fn_key] diff --git a/Agent0/executor_train/verl/verl/workers/reward_manager/prime.py b/Agent0/executor_train/verl/verl/workers/reward_manager/prime.py index f2c526b..60288c0 100644 --- a/Agent0/executor_train/verl/verl/workers/reward_manager/prime.py +++ 
b/Agent0/executor_train/verl/verl/workers/reward_manager/prime.py @@ -26,11 +26,22 @@ from verl.workers.reward_manager import register -async def single_compute_score(evaluation_func, completion, reference, task, task_extra_info, executor, timeout=300.0): +async def single_compute_score( + evaluation_func, + completion, + reference, + task, + task_extra_info, + executor, + timeout=300.0, +): loop = asyncio.get_running_loop() try: # Ensure process_completion is called properly - future = loop.run_in_executor(executor, partial(evaluation_func, task, completion, reference, task_extra_info)) + future = loop.run_in_executor( + executor, + partial(evaluation_func, task, completion, reference, task_extra_info), + ) return await asyncio.wait_for(future, timeout=timeout) except asyncio.TimeoutError: print(f"[Timeout] Task timeout: {completion}") @@ -52,8 +63,12 @@ async def parallel_compute_score_async( try: # Create tasks for all rows tasks_async = [ - single_compute_score(evaluation_func, c, r, t, ei, executor, timeout=300.0) - for c, r, t, ei in zip(completions, references, tasks, extra_info, strict=True) + single_compute_score( + evaluation_func, c, r, t, ei, executor, timeout=300.0 + ) + for c, r, t, ei in zip( + completions, references, tasks, extra_info, strict=True + ) ] results = await asyncio.gather(*tasks_async, return_exceptions=False) except Exception as e: @@ -75,7 +90,9 @@ async def parallel_compute_score_async( print(f"[Shutdown] {terminated_count} subprocess(es) terminated.") # Process results - for result, completion, reference, task in zip(results, completions, references, tasks, strict=True): + for result, completion, reference, task in zip( + results, completions, references, tasks, strict=True + ): if isinstance(result, Exception) or result is None: # Handle failed or timed-out tasks scores.append(0.0) @@ -86,12 +103,21 @@ async def parallel_compute_score_async( return scores -def run_reward_scoring(evaluation_func, completions, references, tasks, 
extra_info=None, num_processes=64): +def run_reward_scoring( + evaluation_func, completions, references, tasks, extra_info=None, num_processes=64 +): loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) try: return loop.run_until_complete( - parallel_compute_score_async(evaluation_func, completions, references, tasks, extra_info, num_processes) + parallel_compute_score_async( + evaluation_func, + completions, + references, + tasks, + extra_info, + num_processes, + ) ) finally: loop.close() @@ -123,8 +149,13 @@ def verify(self, data): prompt_ids = data.batch["prompts"] response_ids = data.batch["responses"] - sequences_str = self.tokenizer.batch_decode(response_ids, skip_special_tokens=True) - ground_truth = [data_item.non_tensor_batch["reward_model"]["ground_truth"] for data_item in data] + sequences_str = self.tokenizer.batch_decode( + response_ids, skip_special_tokens=True + ) + ground_truth = [ + data_item.non_tensor_batch["reward_model"]["ground_truth"] + for data_item in data + ] data_sources = data.non_tensor_batch[self.reward_fn_key] extra_info = data.non_tensor_batch.get("extra_info", None) @@ -144,7 +175,9 @@ def verify(self, data): except Exception as e: print(f"[Error] Unexpected error during scoring. Setting all as 0. 
{e}") scores = [0.0 for _ in range(len(sequences_str))] - data.batch["acc"] = torch.tensor(scores, dtype=torch.float32, device=prompt_ids.device) + data.batch["acc"] = torch.tensor( + scores, dtype=torch.float32, device=prompt_ids.device + ) return scores def __call__(self, data: DataProto, return_dict: bool = False): @@ -163,8 +196,12 @@ def __call__(self, data: DataProto, return_dict: bool = False): prompt_length = prompt_ids.shape[-1] response_ids = data.batch["responses"] - valid_response_length = data.batch["attention_mask"][:, prompt_length:].sum(dim=-1) - sequences_str = self.tokenizer.batch_decode(response_ids, skip_special_tokens=True) + valid_response_length = data.batch["attention_mask"][:, prompt_length:].sum( + dim=-1 + ) + sequences_str = self.tokenizer.batch_decode( + response_ids, skip_special_tokens=True + ) data_sources = data.non_tensor_batch["data_source"] scores = self.verify(data) diff --git a/Agent0/executor_train/verl/verl/workers/reward_model/megatron/reward_model.py b/Agent0/executor_train/verl/verl/workers/reward_model/megatron/reward_model.py index 01b1324..3e1015b 100644 --- a/Agent0/executor_train/verl/verl/workers/reward_model/megatron/reward_model.py +++ b/Agent0/executor_train/verl/verl/workers/reward_model/megatron/reward_model.py @@ -67,7 +67,11 @@ def re_encode_by_rm_tokenizer(self, data: DataProto) -> DataProto: input_ids = data.batch["input_ids"] # (bs, seq_len) attention_mask = data.batch["attention_mask"] position_ids = data.batch["position_ids"] - ori_values = {"input_ids": input_ids, "attention_mask": attention_mask, "position_ids": position_ids} + ori_values = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "position_ids": position_ids, + } _, ori_seqlen = input_ids.size(0), input_ids.size(1) input_ids_for_rm = [] attention_mask_for_rm = [] @@ -97,21 +101,27 @@ def re_encode_by_rm_tokenizer(self, data: DataProto) -> DataProto: ) print_decode = False # 3. 
encode by rm_tokenizer - rm_input_ids = self.rm_tokenizer(decode_with_rm_chat, return_tensors="pt")["input_ids"][0].to( - input_ids.device - ) + rm_input_ids = self.rm_tokenizer(decode_with_rm_chat, return_tensors="pt")[ + "input_ids" + ][0].to(input_ids.device) # 4. generate attention_mask and position_ids rm_attention_mask = torch.ones_like(rm_input_ids, device=input_ids.device) cur_seqlen = rm_input_ids.shape[-1] # NOTE(gh): the later reward compute will process the shape (bs, seqlen_pad_128) if cur_seqlen > ori_seqlen: - print(f"warninig: rm encode seqlen {cur_seqlen} > sft encode seqlen {ori_seqlen}") + print( + f"warninig: rm encode seqlen {cur_seqlen} > sft encode seqlen {ori_seqlen}" + ) rm_input_ids = rm_input_ids[:ori_seqlen] rm_attention_mask = rm_attention_mask[:ori_seqlen] else: # right padding - rm_input_ids = pad_sequence_to_length(rm_input_ids, ori_seqlen, self.rm_tokenizer.pad_token_id) - rm_attention_mask = pad_sequence_to_length(rm_attention_mask, ori_seqlen, 0) + rm_input_ids = pad_sequence_to_length( + rm_input_ids, ori_seqlen, self.rm_tokenizer.pad_token_id + ) + rm_attention_mask = pad_sequence_to_length( + rm_attention_mask, ori_seqlen, 0 + ) rm_position_ids = torch.arange(0, ori_seqlen, device=input_ids.device) input_ids_for_rm.append(torch.unsqueeze(rm_input_ids, dim=0)) attention_mask_for_rm.append(torch.unsqueeze(rm_attention_mask, dim=0)) @@ -142,9 +152,13 @@ def compute_reward(self, data: DataProto) -> DataProto: use_dynamic_bsz = data.meta_info.get("use_dynamic_bsz", False) micro_batch_size = data.meta_info.get("micro_batch_size", None) max_token_len = data.meta_info.get("max_token_len", None) - assert micro_batch_size is not None, "micro batch size is needed for forward compute" + assert ( + micro_batch_size is not None + ), "micro batch size is needed for forward compute" if use_dynamic_bsz: - assert max_token_len is not None, "use_dynamic_bsz is True, but max_token_len is None!" 
+ assert ( + max_token_len is not None + ), "use_dynamic_bsz is True, but max_token_len is None!" max_token_len = max_token_len * self.config.megatron.context_parallel_size responses = data.batch["responses"] @@ -153,15 +167,22 @@ def compute_reward(self, data: DataProto) -> DataProto: with torch.no_grad(): output = self.forward_batch( - data, use_dynamic_bsz=use_dynamic_bsz, micro_batch_size=micro_batch_size, max_token_len=max_token_len + data, + use_dynamic_bsz=use_dynamic_bsz, + micro_batch_size=micro_batch_size, + max_token_len=max_token_len, ) if mpu.is_pipeline_last_stage(ignore_virtual=True): logits = torch.cat(output["output"], dim=0) if use_dynamic_bsz: indices = output["indices"] indices = list(itertools.chain.from_iterable(indices)) - assert len(indices) == logits.size(0), f"{len(indices)} vs. {logits.size()}" - revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long) + assert len(indices) == logits.size( + 0 + ), f"{len(indices)} vs. {logits.size()}" + revert_indices = torch.tensor( + get_reverse_idx(indices), dtype=torch.long + ) logits = logits[revert_indices] else: logits = torch.empty( @@ -190,7 +211,9 @@ def compute_reward(self, data: DataProto) -> DataProto: attention_mask = ori_values["attention_mask"] position_ids = ori_values["position_ids"] - token_level_rewards = rewards.expand(attention_mask.shape[0], attention_mask.shape[1]) # (bs, ori_seqlen) + token_level_rewards = rewards.expand( + attention_mask.shape[0], attention_mask.shape[1] + ) # (bs, ori_seqlen) # assign last valid token reward to ori position if position_ids.dim() == 3: # qwen2vl mrope [bs, 3, seq_len] @@ -208,11 +231,19 @@ def compute_reward(self, data: DataProto) -> DataProto: # add empty cache after each compute get_torch_device().empty_cache() - batch = TensorDict({"rm_scores": token_level_rewards}, batch_size=input_ids.shape[0]) + batch = TensorDict( + {"rm_scores": token_level_rewards}, batch_size=input_ids.shape[0] + ) return DataProto(batch=batch) - def 
forward_batch(self, data: DataProto, use_dynamic_bsz=False, micro_batch_size=None, max_token_len=None): + def forward_batch( + self, + data: DataProto, + use_dynamic_bsz=False, + micro_batch_size=None, + max_token_len=None, + ): """ We assume: - The model takes input: (input_ids, attention_mask, position_ids). No rmpad for the input @@ -230,35 +261,49 @@ def forward_batch(self, data: DataProto, use_dynamic_bsz=False, micro_batch_size mini_batch.batch["attention_mask"] = mini_batch.batch["attention_mask"].to(bool) - self.has_multi_modal_inputs = "multi_modal_inputs" in mini_batch.non_tensor_batch.keys() + self.has_multi_modal_inputs = ( + "multi_modal_inputs" in mini_batch.non_tensor_batch.keys() + ) if self.has_multi_modal_inputs: - mini_batch.batch["multi_modal_inputs"] = mini_batch.non_tensor_batch["multi_modal_inputs"] + mini_batch.batch["multi_modal_inputs"] = mini_batch.non_tensor_batch[ + "multi_modal_inputs" + ] mini_batch.batch["multi_modal_inputs_idx"] = torch.Tensor( list(range(len(mini_batch.non_tensor_batch["multi_modal_inputs"]))) ).to(torch.int64) indices = None if use_dynamic_bsz: - assert max_token_len is not None, "max_token_len must be set when use_dynamic_bsz is True" + assert ( + max_token_len is not None + ), "max_token_len must be set when use_dynamic_bsz is True" vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size() if vpp_size is not None and vpp_size > 1: - microbatch_group_size_per_vp_stage = self.tf_config.microbatch_group_size_per_vp_stage + microbatch_group_size_per_vp_stage = ( + self.tf_config.microbatch_group_size_per_vp_stage + ) micro_batches, indices = rearrange_micro_batches( batch=mini_batch.batch, num_batches_divided_by=microbatch_group_size_per_vp_stage, max_token_len=max_token_len, ) - assert len(micro_batches) % self.tf_config.microbatch_group_size_per_vp_stage == 0, ( + assert ( + len(micro_batches) + % self.tf_config.microbatch_group_size_per_vp_stage + == 0 + ), ( f"micro_batches {micro_batches} must be 
divisible by microbatch_group_size_per_vp_stage " f"{microbatch_group_size_per_vp_stage} for megatron backend" ) else: - micro_batches, indices = rearrange_micro_batches(batch=mini_batch.batch, max_token_len=max_token_len) + micro_batches, indices = rearrange_micro_batches( + batch=mini_batch.batch, max_token_len=max_token_len + ) total_seqlen = max_token_len else: - assert micro_batch_size is not None, ( - "micro_batch_size is needed to be passed in when not using dynamic batch size" - ) + assert ( + micro_batch_size is not None + ), "micro_batch_size is needed to be passed in when not using dynamic batch size" micro_batches = mini_batch.batch.split(micro_batch_size) seq_len = micro_batches[0]["input_ids"].shape[1] total_seqlen = micro_batch_size * seq_len @@ -283,7 +328,11 @@ def forward_step(batch_iter, model): if "multi_modal_inputs" in batch: for key in batch["multi_modal_inputs"][0].keys(): multi_modal_inputs[key] = torch.cat( - [batch["multi_modal_inputs"][i][key] for i in batch["multi_modal_inputs_idx"]], dim=0 + [ + batch["multi_modal_inputs"][i][key] + for i in batch["multi_modal_inputs_idx"] + ], + dim=0, ) output = forward_fn( @@ -299,7 +348,9 @@ def forward_step(batch_iter, model): return output, loss_func # batch should be a list of batches inside micro-batches - batch_generator = make_batch_generator(micro_batches, vpp_size=len(self.reward_model_module)) + batch_generator = make_batch_generator( + micro_batches, vpp_size=len(self.reward_model_module) + ) # TODO: we may use the new schedule instead # for flash-attn: (seq_len, batch_size, hidden_size) = (mbs*seq_len, 1, hidden_size) diff --git a/Agent0/executor_train/verl/verl/workers/rollout/async_server.py b/Agent0/executor_train/verl/verl/workers/rollout/async_server.py index da59c37..d87eff2 100644 --- a/Agent0/executor_train/verl/verl/workers/rollout/async_server.py +++ b/Agent0/executor_train/verl/verl/workers/rollout/async_server.py @@ -57,14 +57,20 @@ async def lifespan(app: fastapi.FastAPI): # 
There's no way to gracefully restart uvicorn server if port is already in use, # so we exit the process directly and let AsyncLLMServerManager restart it. - print("FastAPI shutdown, maybe address already in use, exit process immediately.") + print( + "FastAPI shutdown, maybe address already in use, exit process immediately." + ) os._exit(-1) app = fastapi.FastAPI(lifespan=lifespan) - app.router.add_api_route("/v1/chat/completions", self.chat_completion, methods=["POST"]) + app.router.add_api_route( + "/v1/chat/completions", self.chat_completion, methods=["POST"] + ) self.port = _get_free_port() - config = uvicorn.Config(app, host=["::", "0.0.0.0"], port=self.port, log_level="warning") + config = uvicorn.Config( + app, host=["::", "0.0.0.0"], port=self.port, log_level="warning" + ) server = uvicorn.Server(config) await server.serve() @@ -82,7 +88,9 @@ async def chat_completion(self, raw_request: Request): raise NotImplementedError @abstractmethod - async def generate(self, prompt_ids: list[int], sampling_params: dict[str, Any], request_id: str) -> list[int]: + async def generate( + self, prompt_ids: list[int], sampling_params: dict[str, Any], request_id: str + ) -> list[int]: """Generate response ids given prompt ids. 
Args: @@ -128,7 +136,9 @@ def __init__(self, config: DictConfig, worker_group: RayWorkerGroup): self.rollout_tp_size = self.config.rollout.tensor_model_parallel_size self.rollout_dp_size = self.worker_group.world_size // self.rollout_tp_size - register_center = ray.get_actor(f"{self.worker_group.name_prefix}_register_center") + register_center = ray.get_actor( + f"{self.worker_group.name_prefix}_register_center" + ) workers_info = ray.get(register_center.get_worker_info.remote()) assert len(workers_info) == self.worker_group.world_size @@ -155,7 +165,12 @@ def __init__(self, config: DictConfig, worker_group: RayWorkerGroup): soft=False, ), name=f"async_llm_server_{rollout_dp_rank}", - ).remote(config, self.rollout_dp_size, rollout_dp_rank, self.worker_group.name_prefix) + ).remote( + config, + self.rollout_dp_size, + rollout_dp_rank, + self.worker_group.name_prefix, + ) for rollout_dp_rank in unready_dp_ranks } @@ -167,7 +182,9 @@ def __init__(self, config: DictConfig, worker_group: RayWorkerGroup): unready_dp_ranks.remove(rollout_dp_rank) except Exception: ray.kill(server) - print(f"rollout server {rollout_dp_rank} failed, maybe address already in use, restarting...") + print( + f"rollout server {rollout_dp_rank} failed, maybe address already in use, restarting..." + ) # All server instances are ready, init AsyncLLM engine. 
ray.get([server.init_engine.remote() for server in self.async_llm_servers]) @@ -177,7 +194,9 @@ def __init__(self, config: DictConfig, worker_group: RayWorkerGroup): self.chat_scheduler_exception: Exception = None self.chat_scheduler_loop = None self.chat_scheduler_ready = threading.Event() - self.chat_scheduler_thread = threading.Thread(target=self._init_chat_scheduler, daemon=True) + self.chat_scheduler_thread = threading.Thread( + target=self._init_chat_scheduler, daemon=True + ) self.chat_scheduler_thread.start() self.chat_scheduler_ready.wait() @@ -233,13 +252,16 @@ def generate_sequences(self, prompts: DataProto, **sampling_params) -> DataProto assert self.chat_scheduler is not None, "chat scheduler is not initialized." future = asyncio.run_coroutine_threadsafe( - self.chat_scheduler.generate_sequences(prompts, **sampling_params), self.chat_scheduler_loop + self.chat_scheduler.generate_sequences(prompts, **sampling_params), + self.chat_scheduler_loop, ) return future.result() def async_server_class( - rollout_backend: str, rollout_backend_module: Optional[str] = None, rollout_backend_class: Optional[str] = None + rollout_backend: str, + rollout_backend_module: Optional[str] = None, + rollout_backend_class: Optional[str] = None, ) -> type[AsyncServerBase]: """Get async server class. @@ -257,18 +279,26 @@ def async_server_class( # importlib.import_module and from ... import ... 
have subtle differences in ray if rollout_backend == "vllm": - from verl.workers.rollout.vllm_rollout.vllm_async_server import AsyncvLLMServer + from verl.workers.rollout.vllm_rollout.vllm_async_server import ( + AsyncvLLMServer, + ) return AsyncvLLMServer elif rollout_backend == "sglang": - from verl.workers.rollout.sglang_rollout.async_sglang_server import AsyncSglangServer + from verl.workers.rollout.sglang_rollout.async_sglang_server import ( + AsyncSglangServer, + ) return AsyncSglangServer else: - raise NotImplementedError(f"rollout backend {rollout_backend} is not supported") + raise NotImplementedError( + f"rollout backend {rollout_backend} is not supported" + ) if rollout_backend_module is None or rollout_backend_class is None: - raise ValueError("rollout_backend_module and rollout_backend_class must be both provided for customization") + raise ValueError( + "rollout_backend_module and rollout_backend_class must be both provided for customization" + ) from verl.utils.import_utils import load_extern_type diff --git a/Agent0/executor_train/verl/verl/workers/rollout/chat_scheduler.py b/Agent0/executor_train/verl/verl/workers/rollout/chat_scheduler.py index 268c82d..e77aa2e 100644 --- a/Agent0/executor_train/verl/verl/workers/rollout/chat_scheduler.py +++ b/Agent0/executor_train/verl/verl/workers/rollout/chat_scheduler.py @@ -46,11 +46,18 @@ def __init__(self, config: DictConfig, scheduler: "ChatCompletionScheduler"): self.scheduler = scheduler # Initialize tools from config file - self.max_assistant_turns = config.actor_rollout_ref.rollout.multi_turn.max_assistant_turns + self.max_assistant_turns = ( + config.actor_rollout_ref.rollout.multi_turn.max_assistant_turns + ) tool_config_path = config.actor_rollout_ref.rollout.multi_turn.tool_config_path - tool_list = initialize_tools_from_config(tool_config_path) if tool_config_path else [] + tool_list = ( + initialize_tools_from_config(tool_config_path) if tool_config_path else [] + ) self.tools = {tool.name: tool 
for tool in tool_list} - self._tool_schemas = [tool.tool_schema.model_dump(exclude_unset=True, exclude_none=True) for tool in tool_list] + self._tool_schemas = [ + tool.tool_schema.model_dump(exclude_unset=True, exclude_none=True) + for tool in tool_list + ] print(f"Initialized tools: {self.tools}", flush=True) local_path = copy_to_local(config.actor_rollout_ref.model.path) @@ -67,7 +74,12 @@ def extra_body(self) -> dict[str, Any]: return None @abstractmethod - async def __call__(self, messages: list[dict[str, str]], completions: ChatCompletion, info: dict[str, Any]): + async def __call__( + self, + messages: list[dict[str, str]], + completions: ChatCompletion, + info: dict[str, Any], + ): """Call back function to process completions. Args: @@ -78,7 +90,9 @@ async def __call__(self, messages: list[dict[str, str]], completions: ChatComple raise NotImplementedError @abstractmethod - def postprocess(self, batch: DataProto, batch_conversations: list[list[dict[str, str]]], n: int) -> DataProto: + def postprocess( + self, batch: DataProto, batch_conversations: list[list[dict[str, str]]], n: int + ) -> DataProto: """Post process batch data. 
Args: @@ -101,8 +115,15 @@ def __init__(self, config: DictConfig, scheduler: "ChatCompletionScheduler"): # TODO: add reward manager to calculate reward score once a sample finish - async def __call__(self, messages: list[dict[str, str]], completions: ChatCompletion, info: dict[str, Any]): - message = completions.choices[0].message.model_dump(exclude_unset=True, exclude_none=True) + async def __call__( + self, + messages: list[dict[str, str]], + completions: ChatCompletion, + info: dict[str, Any], + ): + message = completions.choices[0].message.model_dump( + exclude_unset=True, exclude_none=True + ) if "content" not in message: message["content"] = "" messages.append(message) @@ -110,17 +131,23 @@ async def __call__(self, messages: list[dict[str, str]], completions: ChatComple # STEP 0: check if we reach max turns if self.max_assistant_turns and len(messages) >= self.max_assistant_turns: - print(f"[id={completions.id},turn={len(messages)},finish_reason={finish_reason}] Reach max turns, done!") + print( + f"[id={completions.id},turn={len(messages)},finish_reason={finish_reason}] Reach max turns, done!" + ) return # STEP 1: check if the model called tools if finish_reason != "tool_calls": - print(f"[id={completions.id},turn={len(messages)},finish_reason={finish_reason}] No tool called, done!") + print( + f"[id={completions.id},turn={len(messages)},finish_reason={finish_reason}] No tool called, done!" 
+ ) return # STEP 2: call tools tool_calls = completions.choices[0].message.tool_calls - print(f"[id={completions.id},turn={len(messages)},finish_reason={finish_reason}] Call {len(tool_calls)} tools") + print( + f"[id={completions.id},turn={len(messages)},finish_reason={finish_reason}] Call {len(tool_calls)} tools" + ) tasks = [] for tool_call in tool_calls: tasks.append(self._call_tool(tool_call)) @@ -134,7 +161,9 @@ async def __call__(self, messages: list[dict[str, str]], completions: ChatComple messages.extend(tool_responses) # STEP 3: resubmit completion request with tool responses - self.scheduler.submit_chat_completions(messages=messages, request_id=completions.id, info=info) + self.scheduler.submit_chat_completions( + messages=messages, request_id=completions.id, info=info + ) async def _call_tool(self, tool_call) -> dict[str, str]: """Call tool and return tool response.""" @@ -144,7 +173,9 @@ async def _call_tool(self, tool_call) -> dict[str, str]: instance_id = await tool.create() try: - tool_response, tool_reward_score, tool_metrics = await tool.execute(instance_id, tool_args) + tool_response, tool_reward_score, tool_metrics = await tool.execute( + instance_id, tool_args + ) except Exception as e: logger.exception(f"Error when executing tool: {e}") return e @@ -157,7 +188,9 @@ async def _call_tool(self, tool_call) -> dict[str, str]: "tool_call_id": tool_call.id, } - def postprocess(self, batch: DataProto, batch_conversations: list[list[dict[str, str]]], n: int) -> DataProto: + def postprocess( + self, batch: DataProto, batch_conversations: list[list[dict[str, str]]], n: int + ) -> DataProto: # NOTE: consistent with batch version of generate_sequences in vllm_rollout_spmd.py # prompts: left pad # responses: right pad @@ -168,7 +201,10 @@ def postprocess(self, batch: DataProto, batch_conversations: list[list[dict[str, # prompts: [prompt] from input dataset prompts = [ self.tokenizer.apply_chat_template( - prompt, tools=self.tool_schemas, 
add_generation_prompt=True, tokenize=False + prompt, + tools=self.tool_schemas, + add_generation_prompt=True, + tokenize=False, ) for prompt in batch.non_tensor_batch["raw_prompt"] ] @@ -177,19 +213,30 @@ def postprocess(self, batch: DataProto, batch_conversations: list[list[dict[str, # sequences: [prompt + response] sequences = [ self.tokenizer.apply_chat_template( - conversation, tools=self.tool_schemas, add_generation_prompt=False, tokenize=False + conversation, + tools=self.tool_schemas, + add_generation_prompt=False, + tokenize=False, ) for conversation in batch_conversations ] # responses: [response] - responses = [sequence[len(prompts[i // n]) :] for i, sequence in enumerate(sequences)] + responses = [ + sequence[len(prompts[i // n]) :] for i, sequence in enumerate(sequences) + ] - prompts = self.tokenizer(prompts, return_tensors="pt", padding="longest", padding_side="left") - responses = self.tokenizer(responses, return_tensors="pt", padding="longest", padding_side="right") + prompts = self.tokenizer( + prompts, return_tensors="pt", padding="longest", padding_side="left" + ) + responses = self.tokenizer( + responses, return_tensors="pt", padding="longest", padding_side="right" + ) if n > 1: prompts["input_ids"] = prompts["input_ids"].repeat_interleave(n, dim=0) - prompts["attention_mask"] = prompts["attention_mask"].repeat_interleave(n, dim=0) + prompts["attention_mask"] = prompts["attention_mask"].repeat_interleave( + n, dim=0 + ) # response_mask: response mask with tools calling masked out response_mask = self._mask_out_tools_calling_tokens( @@ -200,7 +247,9 @@ def postprocess(self, batch: DataProto, batch_conversations: list[list[dict[str, ) input_ids = torch.cat([prompts["input_ids"], responses["input_ids"]], dim=1) - attention_mask = torch.cat([prompts["attention_mask"], responses["attention_mask"]], dim=1) + attention_mask = torch.cat( + [prompts["attention_mask"], responses["attention_mask"]], dim=1 + ) position_ids = (attention_mask.cumsum(dim=1) - 
1) * attention_mask batch = TensorDict( @@ -215,7 +264,9 @@ def postprocess(self, batch: DataProto, batch_conversations: list[list[dict[str, batch_size=len(input_ids), ) - num_turns = np.array([len(conversation) for conversation in batch_conversations], dtype=np.int32) + num_turns = np.array( + [len(conversation) for conversation in batch_conversations], dtype=np.int32 + ) return DataProto(batch=batch, non_tensor_batch={"__num_turns__": num_turns}) def _mask_out_tools_calling_tokens( @@ -238,7 +289,9 @@ def _mask_out_tools_calling_tokens( """ batch_size = input_ids.size(0) assert len(raw_prompts) == batch_size, f"{len(raw_prompts)} != {batch_size}" - assert len(batch_conversations) == batch_size, f"{len(batch_conversations)} != {batch_size}" + assert ( + len(batch_conversations) == batch_size + ), f"{len(batch_conversations)} != {batch_size}" # Deduplicate adjacent tool calls, since they're merged into one turn. # [user, assistant, tool, tool, assistant] -> [user, assistant, tool, assistant] @@ -257,9 +310,16 @@ def deduplicate_adjacent_tool_calls(roles): responses = batch_conversations[i][len(raw_prompts[i]) :] assert len(responses) > 0, f"responses is empty: {responses}" - roles = deduplicate_adjacent_tool_calls([response["role"] for response in responses]) + roles = deduplicate_adjacent_tool_calls( + [response["role"] for response in responses] + ) # Each turn should be: [BOS]...[EOS] - eos_indices = input_ids[i].eq(self.tokenizer.eos_token_id).nonzero().squeeze(1)[: len(roles)] + eos_indices = ( + input_ids[i] + .eq(self.tokenizer.eos_token_id) + .nonzero() + .squeeze(1)[: len(roles)] + ) for j in range(len(roles)): if roles[j] == "tool": bos = eos_indices[j - 1] + 1 if j > 0 else 0 @@ -299,11 +359,15 @@ def __init__( self.completion_callback = ToolCompletionCallback(config, self) logger.warning("completion_callback is None, use ToolCompletionCallback") else: - module_path, class_name = self.config.multi_turn.completion_callback.rsplit(".", 1) + module_path, 
class_name = self.config.multi_turn.completion_callback.rsplit( + ".", 1 + ) module = importlib.import_module(module_path) self.completion_callback = getattr(module, class_name)(config, self) - def submit_chat_completions(self, *, messages: list[dict[str, str]], request_id: str, info: dict[str, Any]): + def submit_chat_completions( + self, *, messages: list[dict[str, str]], request_id: str, info: dict[str, Any] + ): """Submit chat completion request without wait, completion_callback will be called when the request is done. Args: @@ -312,7 +376,9 @@ def submit_chat_completions(self, *, messages: list[dict[str, str]], request_id: info: Any other auxiliary information pass across multi-turn. """ info["__depth__"] += 1 - task = asyncio.create_task(self._submit_chat_completions_and_callback(messages, request_id, info)) + task = asyncio.create_task( + self._submit_chat_completions_and_callback(messages, request_id, info) + ) # โ€œfire-and-forgetโ€ background tasks self.background_tasks.add(task) @@ -367,11 +433,20 @@ async def _submit_chat_completions_and_callback( if info["__depth__"] == 0: info["__done__"].set() - async def _chat_completions_openai(self, address: str, **chat_complete_request) -> ChatCompletion: - client = AsyncOpenAI(base_url=f"http://{address}/v1", api_key="token-abc123", timeout=None, max_retries=0) + async def _chat_completions_openai( + self, address: str, **chat_complete_request + ) -> ChatCompletion: + client = AsyncOpenAI( + base_url=f"http://{address}/v1", + api_key="token-abc123", + timeout=None, + max_retries=0, + ) return await client.chat.completions.create(**chat_complete_request) - async def _chat_completions_aiohttp(self, address: str, **chat_complete_request) -> ChatCompletion: + async def _chat_completions_aiohttp( + self, address: str, **chat_complete_request + ) -> ChatCompletion: try: extra_body = chat_complete_request.pop("extra_body", {}) chat_complete_request.update(extra_body or {}) @@ -407,7 +482,9 @@ async def 
generate_sequences(self, batch: DataProto) -> DataProto: # validation dataset has already been repeated in `PPOTrainer._validate`. n = 1 if batch.meta_info.get("validate", False) else self.config.n tasks, batch_conversations = [], [None] * len(batch) * n - for batch_index, conversation in enumerate(batch.non_tensor_batch["raw_prompt"].repeat(n, axis=0)): + for batch_index, conversation in enumerate( + batch.non_tensor_batch["raw_prompt"].repeat(n, axis=0) + ): # raw_prompt: [{"role": "user", "content": ""}, ["role": "assistant", "content"], ...] batch_conversations[batch_index] = conversation.tolist() @@ -422,13 +499,18 @@ async def generate_sequences(self, batch: DataProto) -> DataProto: ) await asyncio.gather(*tasks) - output_batch = self.completion_callback.postprocess(batch, batch_conversations, n=n) + output_batch = self.completion_callback.postprocess( + batch, batch_conversations, n=n + ) output_batch.meta_info["timing"] = {"generate_sequences": time.time() - t_start} print("[ChatCompletionScheduler] generate_sequences done") return output_batch async def _submit_chat_completions_semaphore( - self, messages: list[dict[str, str]], request_id: str, sampling_params: dict[str, Any] + self, + messages: list[dict[str, str]], + request_id: str, + sampling_params: dict[str, Any], ): done = asyncio.Event() @@ -438,7 +520,9 @@ async def _submit_chat_completions_semaphore( "__sampling_params__": sampling_params, } - self.submit_chat_completions(messages=messages, request_id=request_id, info=info) + self.submit_chat_completions( + messages=messages, request_id=request_id, info=info + ) # Wait until all completion requests are done await done.wait() diff --git a/Agent0/executor_train/verl/verl/workers/rollout/hf_rollout.py b/Agent0/executor_train/verl/verl/workers/rollout/hf_rollout.py index 32d0bc8..9361e15 100644 --- a/Agent0/executor_train/verl/verl/workers/rollout/hf_rollout.py +++ b/Agent0/executor_train/verl/verl/workers/rollout/hf_rollout.py @@ -44,7 +44,9 @@ def 
__init__(self, module: nn.Module, config): def generate_sequences(self, prompts: DataProto) -> DataProto: batch_size = prompts.batch.batch_size[0] - num_chunks = max(batch_size // self.config.get("micro_batch_size", batch_size), 1) + num_chunks = max( + batch_size // self.config.get("micro_batch_size", batch_size), 1 + ) batch_prompts = prompts.chunk(chunks=num_chunks) output = [self._generate_minibatch(p) for p in batch_prompts] output = DataProto.concat(output) @@ -57,9 +59,13 @@ def _generate_minibatch(self, prompts: DataProto) -> DataProto: is_validate = prompts.meta_info.get("validate", False) temperature = prompts.meta_info.get("temperature", self.config.temperature) - response_length = prompts.meta_info.get("response_length", self.config.response_length) + response_length = prompts.meta_info.get( + "response_length", self.config.response_length + ) top_p = prompts.meta_info.get("top_p", self.config.get("top_p", 1.0)) - top_k = max(0, prompts.meta_info.get("top_k", self.config.get("top_k", 0))) # to be compatible with vllm + top_k = max( + 0, prompts.meta_info.get("top_k", self.config.get("top_k", 0)) + ) # to be compatible with vllm if not do_sample: # do_sample==False -> greedy decoding @@ -72,7 +78,9 @@ def _generate_minibatch(self, prompts: DataProto) -> DataProto: kwargs = { "do_sample": True, "num_beams": 1, - "top_k": max(0, self.config.val_kwargs.top_k), # to be compatible with vllm + "top_k": max( + 0, self.config.val_kwargs.top_k + ), # to be compatible with vllm "top_p": self.config.val_kwargs.top_p, "temperature": self.config.val_kwargs.temperature, "num_return_sequences": 1, # if validate, already repeat in ray_trainer @@ -105,8 +113,12 @@ def _generate_minibatch(self, prompts: DataProto) -> DataProto: if isinstance(self.module, FSDP): # recurse need to set to False according to https://github.com/pytorch/pytorch/issues/100069 - param_ctx = FSDP.summon_full_params(self.module, writeback=False, recurse=False) - with param_ctx, 
torch.autocast(device_type=get_device_name(), dtype=torch.bfloat16): + param_ctx = FSDP.summon_full_params( + self.module, writeback=False, recurse=False + ) + with param_ctx, torch.autocast( + device_type=get_device_name(), dtype=torch.bfloat16 + ): output = self.module.generate( input_ids=idx, attention_mask=attention_mask, @@ -131,7 +143,11 @@ def _generate_minibatch(self, prompts: DataProto) -> DataProto: delta_length = sequence_length - seq.shape[1] if delta_length > 0: - delta_tokens = torch.ones(size=(generated_batch_size, delta_length), device=seq.device, dtype=seq.dtype) + delta_tokens = torch.ones( + size=(generated_batch_size, delta_length), + device=seq.device, + dtype=seq.dtype, + ) delta_tokens = pad_token_id * delta_tokens seq = torch.cat((seq, delta_tokens), dim=1) assert seq.shape[1] == sequence_length @@ -140,14 +156,20 @@ def _generate_minibatch(self, prompts: DataProto) -> DataProto: num_return_sequences = kwargs.get("num_return_sequences", 1) if num_return_sequences > 1: position_ids = position_ids.repeat_interleave(num_return_sequences, dim=0) - attention_mask = attention_mask.repeat_interleave(num_return_sequences, dim=0) + attention_mask = attention_mask.repeat_interleave( + num_return_sequences, dim=0 + ) prompt = seq[:, :prompt_length] # (generated_batch_size, prompt_length) response = seq[:, prompt_length:] # (generated_batch_size, response_length) response_length = response.size(1) - delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device) - delta_position_id = delta_position_id.unsqueeze(0).repeat(generated_batch_size, 1) + delta_position_id = torch.arange( + 1, response_length + 1, device=position_ids.device + ) + delta_position_id = delta_position_id.unsqueeze(0).repeat( + generated_batch_size, 1 + ) response_position_ids = position_ids[:, -1:] + delta_position_id position_ids = torch.cat([position_ids, response_position_ids], dim=-1) diff --git 
a/Agent0/executor_train/verl/verl/workers/rollout/naive/naive_rollout.py b/Agent0/executor_train/verl/verl/workers/rollout/naive/naive_rollout.py index fe56dc4..19446a0 100644 --- a/Agent0/executor_train/verl/verl/workers/rollout/naive/naive_rollout.py +++ b/Agent0/executor_train/verl/verl/workers/rollout/naive/naive_rollout.py @@ -62,7 +62,11 @@ def generate_sequences(self, prompts: DataProto) -> DataProto: self.module.eval() - prev_attention_mask = torch.ones(size=(batch_size, 1), dtype=attention_mask.dtype, device=attention_mask.device) + prev_attention_mask = torch.ones( + size=(batch_size, 1), + dtype=attention_mask.dtype, + device=attention_mask.device, + ) logits_lst = [] for _ in range(self.config.response_length): @@ -71,7 +75,11 @@ def generate_sequences(self, prompts: DataProto) -> DataProto: idx_cond = idx # forward the model to get the logits for the index in the sequence # we use huggingface APIs here - output = self.module(input_ids=idx_cond, attention_mask=attention_mask, position_ids=position_ids) + output = self.module( + input_ids=idx_cond, + attention_mask=attention_mask, + position_ids=position_ids, + ) logits = output.logits # pluck the logits at the final step and scale by desired temperature logits = logits[:, -1, :] / self.config.temperature # (bs, vocab_size) @@ -90,7 +98,9 @@ def generate_sequences(self, prompts: DataProto) -> DataProto: attention_mask = torch.cat((attention_mask, prev_attention_mask), dim=-1) for token_id in eos_token_id: - prev_attention_mask = torch.logical_and(idx_next != token_id, prev_attention_mask.bool()) + prev_attention_mask = torch.logical_and( + idx_next != token_id, prev_attention_mask.bool() + ) prev_attention_mask.to(attention_mask.dtype) position_ids = torch.cat((position_ids, position_ids[:, -1:] + 1), dim=-1) diff --git a/Agent0/executor_train/verl/verl/workers/rollout/schemas.py b/Agent0/executor_train/verl/verl/workers/rollout/schemas.py index 99f860a..e2e5842 100644 --- 
a/Agent0/executor_train/verl/verl/workers/rollout/schemas.py +++ b/Agent0/executor_train/verl/verl/workers/rollout/schemas.py @@ -122,11 +122,17 @@ class AsyncRolloutRequest(BaseModel): @classmethod def initialize_request(cls, values): if not (messages := values.get("messages")): - raise ValueError("messages is required for AsyncRolloutRequest initialization") + raise ValueError( + "messages is required for AsyncRolloutRequest initialization" + ) if not (max_prompt_len := values.get("max_prompt_len")): - raise ValueError("max_prompt_len is required for AsyncRolloutRequest initialization") + raise ValueError( + "max_prompt_len is required for AsyncRolloutRequest initialization" + ) if not (processing_class := values.pop("processing_class", None)): - raise ValueError("processing_class is required for AsyncRolloutRequest initialization") + raise ValueError( + "processing_class is required for AsyncRolloutRequest initialization" + ) values["messages"] = [Message.model_validate(msg) for msg in messages] @@ -144,7 +150,9 @@ def initialize_request(cls, values): values["multi_modal_inputs"] = {} tools = ( - [tool.model_dump() for tool in tool_schemas] if (tool_schemas := values.get("tool_schemas", [])) else None + [tool.model_dump() for tool in tool_schemas] + if (tool_schemas := values.get("tool_schemas", [])) + else None ) multi_modal_data = values["multi_modal_data"] @@ -189,13 +197,25 @@ def initialize_request(cls, values): multi_modal_inputs.pop("attention_mask", None) values["multi_modal_inputs"] = multi_modal_inputs - values["position_ids"] = values["prompt_position_ids"] = cls._get_position_ids( - processing_class, values["input_ids"], values["attention_mask"], multi_modal_inputs + values["position_ids"] = values["prompt_position_ids"] = ( + cls._get_position_ids( + processing_class, + values["input_ids"], + values["attention_mask"], + multi_modal_inputs, + ) ) - values["prompt_ids"], values["prompt_attention_mask"] = values["input_ids"], values["attention_mask"] - 
values["loss_mask"] = values["prompt_loss_mask"] = torch.zeros_like(values["input_ids"], dtype=torch.bool) - values["generation_prompt_ids"] = values["input_ids"][..., tokens_without_prompt.shape[-1] :] + values["prompt_ids"], values["prompt_attention_mask"] = ( + values["input_ids"], + values["attention_mask"], + ) + values["loss_mask"] = values["prompt_loss_mask"] = torch.zeros_like( + values["input_ids"], dtype=torch.bool + ) + values["generation_prompt_ids"] = values["input_ids"][ + ..., tokens_without_prompt.shape[-1] : + ] values["base_conv_wo_gen_prompt_end_pos"] = cls._handle_apply_chat_template( processing_class, BASE_CHAT_HISTORY, @@ -218,7 +238,9 @@ def initialize_request(cls, values): @staticmethod def _handle_apply_chat_template( - processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin, + processing_class: ( + PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin + ), messages: list[Message], multi_modal_data: dict[str, Any], tools: Optional[list[OpenAIFunctionToolSchema]] = None, @@ -227,12 +249,17 @@ def _handle_apply_chat_template( return_dict: bool = False, ): raw_prompt = processing_class.apply_chat_template( - messages, tools=tools, add_generation_prompt=add_generation_prompt, tokenize=False + messages, + tools=tools, + add_generation_prompt=add_generation_prompt, + tokenize=False, ) if not tokenize: return raw_prompt - if isinstance(processing_class, PreTrainedTokenizer) or isinstance(processing_class, PreTrainedTokenizerFast): + if isinstance(processing_class, PreTrainedTokenizer) or isinstance( + processing_class, PreTrainedTokenizerFast + ): if any(len(values) > 0 for values in multi_modal_data.values()): logger.warning( "There is multi_modal_data but you are not using a processor. Multi-modal data will be ignored." 
@@ -240,11 +267,19 @@ def _handle_apply_chat_template( model_inputs = processing_class(text=[raw_prompt], return_tensors="pt") elif isinstance(processing_class, ProcessorMixin): # When we update multi_model_keys, we also need to update this logic - images = images if len(images := multi_modal_data.get("image", [])) > 0 else None - videos = videos if len(videos := multi_modal_data.get("video", [])) > 0 else None - model_inputs = processing_class(text=[raw_prompt], images=images, videos=videos, return_tensors="pt") + images = ( + images if len(images := multi_modal_data.get("image", [])) > 0 else None + ) + videos = ( + videos if len(videos := multi_modal_data.get("video", [])) > 0 else None + ) + model_inputs = processing_class( + text=[raw_prompt], images=images, videos=videos, return_tensors="pt" + ) else: - raise ValueError(f"Unsupported processing class type: {type(processing_class)}") + raise ValueError( + f"Unsupported processing class type: {type(processing_class)}" + ) model_inputs = dict(model_inputs) if return_dict: @@ -254,7 +289,9 @@ def _handle_apply_chat_template( @staticmethod def _get_position_ids( - processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin, + processing_class: ( + PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin + ), input_ids: torch.Tensor, attention_mask: torch.Tensor, multi_modal_inputs: Optional[dict[str, torch.Tensor]] = None, @@ -262,7 +299,8 @@ def _get_position_ids( # special case for qwen2vl is_qwen2vl = ( hasattr(processing_class, "image_processor") - and "Qwen2VLImageProcessor" in processing_class.image_processor.__class__.__name__ + and "Qwen2VLImageProcessor" + in processing_class.image_processor.__class__.__name__ ) if is_qwen2vl: from verl.models.transformers.qwen2_vl import get_rope_index @@ -273,12 +311,12 @@ def _get_position_ids( video_grid_thw = multi_modal_inputs.get("video_grid_thw") second_per_grid_ts = multi_modal_inputs.get("second_per_grid_ts") - assert input_ids.dim() 
== 2 and input_ids.shape[0] == 1, ( - f"input_ids should be 2D with batch size 1, but got shape {input_ids.shape}" - ) - assert attention_mask.dim() == 2 and attention_mask.shape[0] == 1, ( - f"attention_mask should be 2D with batch size 1, but got shape {attention_mask.shape}" - ) + assert ( + input_ids.dim() == 2 and input_ids.shape[0] == 1 + ), f"input_ids should be 2D with batch size 1, but got shape {input_ids.shape}" + assert ( + attention_mask.dim() == 2 and attention_mask.shape[0] == 1 + ), f"attention_mask should be 2D with batch size 1, but got shape {attention_mask.shape}" new_position_ids = get_rope_index( processing_class, input_ids=input_ids.squeeze(0), @@ -293,7 +331,9 @@ def _get_position_ids( def _update_input_ids( self, - processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin, + processing_class: ( + PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin + ), new_input_ids: torch.Tensor, attention_mask: bool, loss_mask: bool, @@ -328,7 +368,9 @@ def _update_input_ids( ), f"""Request {self.request_id} has different length of {self.input_ids.shape[-1]=}, {self.attention_mask.shape[-1]=}, {self.position_ids.shape[-1]=}, {self.loss_mask.shape[-1]=}""" - def _update_multi_modal_inputs(self, new_multi_modal_inputs: dict[str, torch.Tensor]) -> None: + def _update_multi_modal_inputs( + self, new_multi_modal_inputs: dict[str, torch.Tensor] + ) -> None: """ Update the multi_modal_inputs of the request in additive manner. """ @@ -341,7 +383,10 @@ def _update_multi_modal_inputs(self, new_multi_modal_inputs: dict[str, torch.Ten ) def get_generation_prompt_ids( - self, processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin + self, + processing_class: ( + PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin + ), ) -> list[int]: """ Get the generation prompt ids for rollout engine. 
@@ -350,15 +395,26 @@ def get_generation_prompt_ids( """ generation_prompt_ids = ( None - if self.input_ids[..., -self.generation_prompt_ids.shape[-1] :].eq(self.generation_prompt_ids).all() + if self.input_ids[..., -self.generation_prompt_ids.shape[-1] :] + .eq(self.generation_prompt_ids) + .all() else self.generation_prompt_ids ) if generation_prompt_ids is not None: - self._update_input_ids(processing_class, generation_prompt_ids, attention_mask=True, loss_mask=False) + self._update_input_ids( + processing_class, + generation_prompt_ids, + attention_mask=True, + loss_mask=False, + ) if self.use_inference_chat_template: messages = [msg.model_dump() for msg in self.messages] - tools = [tool.model_dump() for tool in self.tool_schemas] if self.tool_schemas else None + tools = ( + [tool.model_dump() for tool in self.tool_schemas] + if self.tool_schemas + else None + ) generation_prompt_ids = self._handle_apply_chat_template( processing_class, messages, @@ -373,41 +429,71 @@ def get_generation_prompt_ids( def add_user_message( self, - processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin, + processing_class: ( + PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin + ), content: str, ) -> None: self.messages.append(Message(role="user", content=content)) messages = [*BASE_CHAT_HISTORY, self.messages[-1]] - tools = [tool.model_dump() for tool in self.tool_schemas] if self.tool_schemas else None + tools = ( + [tool.model_dump() for tool in self.tool_schemas] + if self.tool_schemas + else None + ) # We don't need to pass multi_modal_data here because we don't have any multi-modal data from Engine # Inference, it is pure text. 
content_ids = self._handle_apply_chat_template( - processing_class, messages, multi_modal_data={}, tools=tools, add_generation_prompt=False, tokenize=True + processing_class, + messages, + multi_modal_data={}, + tools=tools, + add_generation_prompt=False, + tokenize=True, )[..., self.base_conv_wo_gen_prompt_end_pos :] - self._update_input_ids(processing_class, content_ids, attention_mask=True, loss_mask=False) + self._update_input_ids( + processing_class, content_ids, attention_mask=True, loss_mask=False + ) def add_assistant_message( self, - processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin, + processing_class: ( + PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin + ), content: str, tool_calls: Optional[list[OpenAIFunctionToolCall]] = None, ) -> None: - self.messages.append(Message(role="assistant", content=content, tool_calls=tool_calls)) + self.messages.append( + Message(role="assistant", content=content, tool_calls=tool_calls) + ) messages = [*BASE_CHAT_HISTORY, self.messages[-1]] - tools = [tool.model_dump() for tool in self.tool_schemas] if self.tool_schemas else None + tools = ( + [tool.model_dump() for tool in self.tool_schemas] + if self.tool_schemas + else None + ) # We don't need to pass multi_modal_data here because we don't have any multi-modal data from Engine # Inference, it is pure text. 
content_ids = self._handle_apply_chat_template( - processing_class, messages, multi_modal_data={}, tools=tools, add_generation_prompt=False, tokenize=True + processing_class, + messages, + multi_modal_data={}, + tools=tools, + add_generation_prompt=False, + tokenize=True, )[..., self.base_conv_with_gen_prompt_end_pos :] - self._update_input_ids(processing_class, content_ids, attention_mask=True, loss_mask=True) + self._update_input_ids( + processing_class, content_ids, attention_mask=True, loss_mask=True + ) def add_tool_response_messages( self, - processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin, + processing_class: ( + PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin + ), contents: list[str | dict[str, Any]], ) -> None: if not contents: @@ -452,7 +538,11 @@ def add_tool_response_messages( self.messages.append(Message(role="tool", content=content)) messages = [*BASE_CHAT_HISTORY, *self.messages[-len(contents) :]] - tools = [tool.model_dump() for tool in self.tool_schemas] if self.tool_schemas else None + tools = ( + [tool.model_dump() for tool in self.tool_schemas] + if self.tool_schemas + else None + ) for key in self.multi_modal_keys: if len(delta_multi_modal_data[key]) > 0: @@ -468,7 +558,9 @@ def add_tool_response_messages( tokenize=True, return_dict=True, ) - content_ids = content_info["input_ids"][..., self.base_conv_wo_gen_prompt_end_pos :] + content_ids = content_info["input_ids"][ + ..., self.base_conv_wo_gen_prompt_end_pos : + ] # process multi_modal_inputs multi_modal_inputs = content_info.copy() @@ -492,7 +584,9 @@ def update_metrics(self, metrics: Any, tool_id: str) -> None: def _get_prompt_diffs( self, - processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin, + processing_class: ( + PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin + ), full_prompt_ids: torch.Tensor, current_prompt_ids: torch.Tensor, diff_surrounding_chars: int = 10, @@ -524,8 +618,12 @@ def 
_get_prompt_diffs( """ full_prompt_ids = full_prompt_ids.squeeze(0) current_prompt_ids = current_prompt_ids.squeeze(0) - full_prompt = processing_class.decode(full_prompt_ids, skip_special_tokens=False) - current_prompt = processing_class.decode(current_prompt_ids, skip_special_tokens=False) + full_prompt = processing_class.decode( + full_prompt_ids, skip_special_tokens=False + ) + current_prompt = processing_class.decode( + current_prompt_ids, skip_special_tokens=False + ) s = difflib.SequenceMatcher(None, full_prompt, current_prompt, autojunk=False) diffs = [] for tag, i1, i2, j1, j2 in s.get_opcodes(): @@ -549,7 +647,9 @@ def _get_prompt_diffs( def finalize( self, - processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin, + processing_class: ( + PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin + ), reward_scores: dict[str, list[float]], finish_reason_type: FinishReasonTypeEnum = FinishReasonTypeEnum.STOP, ) -> None: @@ -558,20 +658,39 @@ def finalize( # In case we failed to generate the assistant message and the generation prompt ids were already added to # input_ids, remove them from the end of input_ids - if self.input_ids[..., -self.generation_prompt_ids.shape[-1] :].eq(self.generation_prompt_ids).all(): - self.input_ids = self.input_ids[..., : -self.generation_prompt_ids.shape[-1]] - self.attention_mask = self.attention_mask[..., : -self.generation_prompt_ids.shape[-1]] - self.position_ids = self.position_ids[..., : -self.generation_prompt_ids.shape[-1]] - self.loss_mask = self.loss_mask[..., : -self.generation_prompt_ids.shape[-1]] + if ( + self.input_ids[..., -self.generation_prompt_ids.shape[-1] :] + .eq(self.generation_prompt_ids) + .all() + ): + self.input_ids = self.input_ids[ + ..., : -self.generation_prompt_ids.shape[-1] + ] + self.attention_mask = self.attention_mask[ + ..., : -self.generation_prompt_ids.shape[-1] + ] + self.position_ids = self.position_ids[ + ..., : -self.generation_prompt_ids.shape[-1] + ] + 
self.loss_mask = self.loss_mask[ + ..., : -self.generation_prompt_ids.shape[-1] + ] self.response_ids = self.input_ids[..., self.prompt_ids.shape[-1] :] - if self.tokenization_sanity_check_mode != TokenizationSanityCheckModeEnum.DISABLE: + if ( + self.tokenization_sanity_check_mode + != TokenizationSanityCheckModeEnum.DISABLE + ): # When there is a diff, we log the diffs with diff_surrounding_chars context diff_surrounding_chars = 10 messages = [msg.model_dump() for msg in self.messages] - tools = [tool.model_dump() for tool in self.tool_schemas] if self.tool_schemas else None + tools = ( + [tool.model_dump() for tool in self.tool_schemas] + if self.tool_schemas + else None + ) full_prompt_info = self._handle_apply_chat_template( processing_class, messages, @@ -609,14 +728,25 @@ def finalize( ) if diffs := self._get_prompt_diffs( - processing_class, full_prompt_ids, self.input_ids, diff_surrounding_chars=diff_surrounding_chars + processing_class, + full_prompt_ids, + self.input_ids, + diff_surrounding_chars=diff_surrounding_chars, ): log_warning = False - if self.tokenization_sanity_check_mode == TokenizationSanityCheckModeEnum.STRICT: + if ( + self.tokenization_sanity_check_mode + == TokenizationSanityCheckModeEnum.STRICT + ): log_warning = True - elif self.tokenization_sanity_check_mode == TokenizationSanityCheckModeEnum.IGNORE_STRIPPABLE: + elif ( + self.tokenization_sanity_check_mode + == TokenizationSanityCheckModeEnum.IGNORE_STRIPPABLE + ): non_strippable_diffs_exist = any( - d["full_prompt_chunk"].strip() or d["current_prompt_chunk"].strip() for d in diffs + d["full_prompt_chunk"].strip() + or d["current_prompt_chunk"].strip() + for d in diffs ) if non_strippable_diffs_exist: log_warning = True @@ -647,7 +777,9 @@ def finalize( elif finish_reason_type == FinishReasonTypeEnum.LENGTH: pass else: - raise ValueError(f"Unsupported finalize finish reason type: {finish_reason_type}") + raise ValueError( + f"Unsupported finalize finish reason type: 
{finish_reason_type}" + ) self.truncate_output_ids(processing_class) assert ( @@ -659,17 +791,24 @@ def finalize( {self.attention_mask.shape[-1]=}, {self.position_ids.shape[-1]=}, {self.loss_mask.shape[-1]=}""" def truncate_output_ids( - self, processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin + self, + processing_class: ( + PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin + ), ) -> None: self.input_ids = self.input_ids[..., : self.max_model_len] self.attention_mask = self.attention_mask[..., : self.max_model_len] self.position_ids = self.position_ids[..., : self.max_model_len] self.loss_mask = self.loss_mask[..., : self.max_model_len] - self.response_ids = self.input_ids[..., self.prompt_ids.shape[-1] :][..., : self.max_response_len] - self.response_attention_mask = self.attention_mask[..., self.prompt_attention_mask.shape[-1] :][ - ..., : self.max_response_len - ] - self.response_position_ids = self.position_ids[..., self.prompt_position_ids.shape[-1] :][ + self.response_ids = self.input_ids[..., self.prompt_ids.shape[-1] :][ ..., : self.max_response_len ] - self.response_loss_mask = self.loss_mask[..., self.prompt_loss_mask.shape[-1] :][..., : self.max_response_len] + self.response_attention_mask = self.attention_mask[ + ..., self.prompt_attention_mask.shape[-1] : + ][..., : self.max_response_len] + self.response_position_ids = self.position_ids[ + ..., self.prompt_position_ids.shape[-1] : + ][..., : self.max_response_len] + self.response_loss_mask = self.loss_mask[ + ..., self.prompt_loss_mask.shape[-1] : + ][..., : self.max_response_len] diff --git a/Agent0/executor_train/verl/verl/workers/rollout/sglang_rollout/async_sglang_server.py b/Agent0/executor_train/verl/verl/workers/rollout/sglang_rollout/async_sglang_server.py index df26765..eb88a2e 100644 --- a/Agent0/executor_train/verl/verl/workers/rollout/sglang_rollout/async_sglang_server.py +++ 
b/Agent0/executor_train/verl/verl/workers/rollout/sglang_rollout/async_sglang_server.py @@ -44,7 +44,9 @@ async def init_engine(self): return all_actors = ray.util.list_named_actors(all_namespaces=True) matched_actors = [ - actor for actor in all_actors if actor.get("name", None).startswith(self.wg_prefix + "WorkerDict_") + actor + for actor in all_actors + if actor.get("name", None).startswith(self.wg_prefix + "WorkerDict_") ] for matched_actor in matched_actors: @@ -52,10 +54,14 @@ async def init_engine(self): assert len(fields) == 2, f"invalid actor name: {matched_actor['name']}" pg_index, local_rank = int(fields[0].split("_")[-1]), int(fields[1]) - if (self._dp_size * pg_index + local_rank) // self._tp_size == self._dp_rank: + if ( + self._dp_size * pg_index + local_rank + ) // self._tp_size == self._dp_rank: worker = ray.get_actor(**matched_actor) self.workers.append(worker) - if (self._dp_size * pg_index + local_rank) / self._tp_size == self._dp_rank: + if ( + self._dp_size * pg_index + local_rank + ) / self._tp_size == self._dp_rank: self.master_worker = worker async def chat_completion(self, raw_request: Request): @@ -66,8 +72,12 @@ async def chat_completion(self, raw_request: Request): [outputs] = await asyncio.gather(output_future) return JSONResponse(outputs) - async def generate(self, prompt_ids: list[int], sampling_params: dict[str, Any], request_id: str) -> list[int]: - return await self.master_worker.generate.remote(prompt_ids, sampling_params, request_id) + async def generate( + self, prompt_ids: list[int], sampling_params: dict[str, Any], request_id: str + ) -> list[int]: + return await self.master_worker.generate.remote( + prompt_ids, sampling_params, request_id + ) async def wake_up(self): if not self.config.rollout.free_cache_engine: diff --git a/Agent0/executor_train/verl/verl/workers/rollout/sglang_rollout/sglang_rollout.py b/Agent0/executor_train/verl/verl/workers/rollout/sglang_rollout/sglang_rollout.py index 3c66943..8187ba7 100644 --- 
a/Agent0/executor_train/verl/verl/workers/rollout/sglang_rollout/sglang_rollout.py +++ b/Agent0/executor_train/verl/verl/workers/rollout/sglang_rollout/sglang_rollout.py @@ -54,10 +54,16 @@ from verl import DataProto from verl.interactions.base import BaseInteraction -from verl.interactions.utils.interaction_registry import initialize_interactions_from_config +from verl.interactions.utils.interaction_registry import ( + initialize_interactions_from_config, +) from verl.third_party.sglang import parallel_state as sglang_ps from verl.tools.base_tool import BaseTool -from verl.tools.schemas import OpenAIFunctionCallSchema, OpenAIFunctionParsedSchema, OpenAIFunctionToolCall +from verl.tools.schemas import ( + OpenAIFunctionCallSchema, + OpenAIFunctionParsedSchema, + OpenAIFunctionToolCall, +) from verl.tools.utils.tool_registry import initialize_tools_from_config from verl.utils.net_utils import is_ipv6 from verl.utils.profiler import GPUMemoryLogger @@ -170,7 +176,8 @@ async def update_weights_from_tensor( to avoid duplicated cache cleaning operation.""" obj = UpdateWeightsFromTensorReqInput( serialized_named_tensors=[ - MultiprocessingSerializer.serialize(named_tensors) for _ in range(self.server_args.tp_size) + MultiprocessingSerializer.serialize(named_tensors) + for _ in range(self.server_args.tp_size) ], load_format=load_format, flush_cache=flush_cache, @@ -188,7 +195,9 @@ def _pre_process_inputs( prompt_token_ids: torch.Tensor, ) -> torch.Tensor: # remove the left padding in the prompt token_id - non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0] + non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][ + 0 + ] return prompt_token_ids[non_pad_index:] @@ -202,12 +211,18 @@ def _post_process_outputs(processing_class, output): # This is when processing_class is a tokenizer tokenizer = processing_class except AttributeError as e: - raise ValueError(f"Cannot get tokenizer from processing_class 
{processing_class}") from e + raise ValueError( + f"Cannot get tokenizer from processing_class {processing_class}" + ) from e def _map_each_response(resp): output_token_logprobs = resp["meta_info"]["output_token_logprobs"] log_probs, output_token_ids = zip( - *[(log_prob, token_ids) for log_prob, token_ids, _ in output_token_logprobs], strict=True + *[ + (log_prob, token_ids) + for log_prob, token_ids, _ in output_token_logprobs + ], + strict=True, ) return torch.tensor(output_token_ids), torch.tensor(log_probs) @@ -217,10 +232,18 @@ def _map_each_response(resp): for output_token_ids, log_probs in out_map: batched_output_token_ids.append(output_token_ids) batched_logprobs.append(log_probs) - pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id - batched_output_token_ids = pad_sequence(batched_output_token_ids, batch_first=True, padding_value=pad_token_id) + pad_token_id = ( + tokenizer.pad_token_id + if tokenizer.pad_token_id is not None + else tokenizer.eos_token_id + ) + batched_output_token_ids = pad_sequence( + batched_output_token_ids, batch_first=True, padding_value=pad_token_id + ) if len(batched_logprobs) > 0: - batched_logprobs = pad_sequence(batched_logprobs, batch_first=True, padding_value=pad_token_id) + batched_logprobs = pad_sequence( + batched_logprobs, batch_first=True, padding_value=pad_token_id + ) return batched_output_token_ids, batched_logprobs @@ -238,14 +261,18 @@ def get_tool_call_parser_type( # This is when processing_class is a processor tokenizer_vocab = processing_class.tokenizer.get_vocab() except AttributeError as e: - raise ValueError(f"Cannot get vocab from processing_class {processing_class}") from e + raise ValueError( + f"Cannot get vocab from processing_class {processing_class}" + ) from e if parser.bot_token.strip() in tokenizer_vocab and ( parser.eot_token == "" or parser.eot_token.strip() in tokenizer_vocab ): return parser_type else: - raise ValueError(f"No tool call parser 
found for processing_class {processing_class}") + raise ValueError( + f"No tool call parser found for processing_class {processing_class}" + ) class SGLangRollout(BaseRollout): @@ -253,7 +280,9 @@ def __init__( self, actor_module: str, config: DictConfig, - processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin, + processing_class: ( + PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin + ), model_hf_config, port=None, trust_remote_code: bool = False, @@ -294,7 +323,9 @@ def __init__( self._sgl_tools, self._function_call_parser, ) = self._initialize_tools(config, processing_class) - self.interaction_map: dict[str, BaseInteraction] = self._initialize_interactions(config) + self.interaction_map: dict[str, BaseInteraction] = ( + self._initialize_interactions(config) + ) # If turn on `free_cache_engine`, SGLang engine's KV cache # will be freed after each `generate_sequences` call. logger.info( @@ -321,15 +352,17 @@ def __init__( # This is when processing_class is a processor self.pad_token_id = self.processing_class.tokenizer.pad_token_id except AttributeError as e: - raise ValueError(f"Cannot get pad_token_id from processing_class {self.processing_class}") from e + raise ValueError( + f"Cannot get pad_token_id from processing_class {self.processing_class}" + ) from e def _init_distributed_env(self, device_mesh_cpu, **kwargs): self._device_mesh_cpu = device_mesh_cpu os.environ.setdefault("SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK", "true") self.tensor_parallel_size = self.config.get("tensor_model_parallel_size", 1) - assert self.tensor_parallel_size <= dist.get_world_size(), ( - "tensor parallel size should be less than or equal to the world size" - ) + assert ( + self.tensor_parallel_size <= dist.get_world_size() + ), "tensor parallel size should be less than or equal to the world size" self.train_tp = kwargs.get("train_tp", None) if self.train_tp is not None: # deployed with megatron @@ -358,39 +391,53 @@ def 
_init_distributed_env(self, device_mesh_cpu, **kwargs): self._tp_rank = self._device_mesh_cpu["tp"].get_local_rank() self._tp_size = self._device_mesh_cpu["tp"].size() if self._rank == 0: - logger.info(f"_init_distributed_env: :tp_world: {self._tp_size}, global_world: {world_size}") + logger.info( + f"_init_distributed_env: :tp_world: {self._tp_size}, global_world: {world_size}" + ) # get tp_rank of this process in this tp group visible_devices = [None] * self._device_mesh_cpu.size(1) torch.distributed.all_gather_object( - visible_devices, os.environ["CUDA_VISIBLE_DEVICES"], self._device_mesh_cpu.get_group("tp") + visible_devices, + os.environ["CUDA_VISIBLE_DEVICES"], + self._device_mesh_cpu.get_group("tp"), ) self.visible_devices_set = set(",".join(visible_devices).split(",")) - os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(sorted(list(self.visible_devices_set))) + os.environ["CUDA_VISIBLE_DEVICES"] = ",".join( + sorted(list(self.visible_devices_set)) + ) def _verify_config(self, model_hf_config): if not self.config.get("max_model_len", None): - self.config.max_model_len = self.config.prompt_length + self.config.response_length + self.config.max_model_len = ( + self.config.prompt_length + self.config.response_length + ) assert ( - self.config.max_model_len >= self.config.prompt_length + self.config.response_length + self.config.max_model_len + >= self.config.prompt_length + self.config.response_length ), f"""max_model_len should be greater than total sequence length (prompt_length + response_length): {self.config.max_model_len} >= {self.config.prompt_length} + {self.config.response_length}""" max_position_embeddings = None if hasattr(model_hf_config, "max_position_embeddings"): max_position_embeddings = model_hf_config.max_position_embeddings - elif hasattr(model_hf_config, "llm_config") and hasattr(model_hf_config.llm_config, "max_position_embeddings"): + elif hasattr(model_hf_config, "llm_config") and hasattr( + model_hf_config.llm_config, 
"max_position_embeddings" + ): max_position_embeddings = model_hf_config.llm_config.max_position_embeddings elif hasattr(model_hf_config, "text_config") and hasattr( model_hf_config.text_config, "max_position_embeddings" ): - max_position_embeddings = model_hf_config.text_config.max_position_embeddings + max_position_embeddings = ( + model_hf_config.text_config.max_position_embeddings + ) if max_position_embeddings is None: raise ValueError("max_position_embeddings not found in model_hf_config") rope_scaling_config = getattr(model_hf_config, "rope_scaling", None) if not rope_scaling_config: - assert max_position_embeddings >= self.config.prompt_length + self.config.response_length, ( - "model context length should be greater than total sequence length" - ) + assert ( + max_position_embeddings + >= self.config.prompt_length + self.config.response_length + ), "model context length should be greater than total sequence length" else: # handle type where there's a length extend factor # see https://qwen.readthedocs.io/en/latest/deployment/vllm.html#extended-context-support @@ -429,7 +476,11 @@ def _init_inference_engine(self, trust_remote_code, actor_module, port): else: dist_init_addr = None - load_format = "dummy" if self.config.load_format.startswith("dummy") else self.config.load_format + load_format = ( + "dummy" + if self.config.load_format.startswith("dummy") + else self.config.load_format + ) tp_size_per_node = self._tp_size // nnodes node_rank = self._tp_rank // tp_size_per_node first_rank_in_node = self._tp_rank % tp_size_per_node == 0 @@ -517,7 +568,9 @@ def _initialize_tools(self, config, processing_class): tool_list = initialize_tools_from_config(tools_config_file) logger.info(f"Initialize tools from configuration.: tool_list: {tool_list}") - tool_schemas = [tool.get_openai_tool_schema().model_dump() for tool in tool_list] + tool_schemas = [ + tool.get_openai_tool_schema().model_dump() for tool in tool_list + ] tool_map = {tool.name: tool for tool in 
tool_list} tool_call_parser_type = get_tool_call_parser_type(processing_class) sgl_tools = [Tool.model_validate(tool_schema) for tool_schema in tool_schemas] @@ -546,7 +599,9 @@ def _initialize_interactions(self, config): interaction_config_file = config.multi_turn.interaction_config_path interaction_map = initialize_interactions_from_config(interaction_config_file) - logger.info(f"Initialize interactions from configuration: interaction_map: {list(interaction_map.keys())}") + logger.info( + f"Initialize interactions from configuration: interaction_map: {list(interaction_map.keys())}" + ) return interaction_map @GPUMemoryLogger(role="sglang rollout", logger=logger) @@ -578,7 +633,9 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto: @GPUMemoryLogger(role="sglang rollout", logger=logger) @torch.no_grad() - def _batch_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto: + def _batch_level_generate_sequences( + self, prompts: DataProto, **kwargs + ) -> DataProto: """Generates single-turn sequences for a batch of prompts. For single-turn generation, all prompts are processed in one request. 
`_batch_level_generate_sequences` involves: @@ -635,7 +692,10 @@ def _batch_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataP non_tensor_batch = prompts.non_tensor_batch if "raw_prompt_ids" not in non_tensor_batch: non_tensor_batch["raw_prompt_ids"] = np.array( - [_pre_process_inputs(self.pad_token_id, idx[i]).tolist() for i in range(batch_size)], + [ + _pre_process_inputs(self.pad_token_id, idx[i]).tolist() + for i in range(batch_size) + ], dtype=object, ) @@ -651,13 +711,16 @@ def _batch_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataP "prompt_token_ids": raw_prompt_ids, "multi_modal_data": multi_modal_data, "image_data": ( - multi_modal_data.get("image", None) if isinstance(multi_modal_data, dict) else None + multi_modal_data.get("image", None) + if isinstance(multi_modal_data, dict) + else None ), } ) else: sglang_inputs = [ - {"prompt_token_ids": raw_prompt_ids} for raw_prompt_ids in non_tensor_batch.pop("raw_prompt_ids") + {"prompt_token_ids": raw_prompt_ids} + for raw_prompt_ids in non_tensor_batch.pop("raw_prompt_ids") ] # Ensure token IDs are lists or numpy arrays @@ -671,7 +734,9 @@ def _batch_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataP # Extract token IDs and image data for SGLang Engine idx_list = [input_data["prompt_token_ids"] for input_data in sglang_inputs] - image_list = [input_data.get("image_data", None) for input_data in sglang_inputs] + image_list = [ + input_data.get("image_data", None) for input_data in sglang_inputs + ] do_sample = prompts.meta_info.get("do_sample", True) is_validate = prompts.meta_info.get("validate", False) @@ -739,7 +804,9 @@ def _batch_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataP rollout_log_probs = out[1].to(idx.device) if response.shape[1] < self.config.response_length: - response = pad_sequence_to_length(response, self.config.response_length, self.pad_token_id) + response = pad_sequence_to_length( + response, 
self.config.response_length, self.pad_token_id + ) if self.config.calculate_log_probs: rollout_log_probs = pad_sequence_to_length( rollout_log_probs, self.config.response_length, self.pad_token_id @@ -748,10 +815,14 @@ def _batch_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataP seq = torch.cat([idx, response], dim=-1) response_length = response.size(1) - delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device) + delta_position_id = torch.arange( + 1, response_length + 1, device=position_ids.device + ) delta_position_id = delta_position_id.unsqueeze(0).repeat(batch_size, 1) if position_ids.dim() == 3: # qwen2vl mrope - delta_position_id = delta_position_id.view(batch_size, 1, -1).expand(batch_size, 3, -1) + delta_position_id = delta_position_id.view(batch_size, 1, -1).expand( + batch_size, 3, -1 + ) # TODO(sgm): fix position_ids on right_pad # prompt: left pad + response: right pad @@ -846,25 +917,37 @@ async def _async_rollout_a_request( self._tool_map[tool_call.function.name].execute( _req.request_id, tool_call.function.arguments, - **_req.tools_kwargs[tool_call.function.name].get("execute_kwargs", {}), + **_req.tools_kwargs[tool_call.function.name].get( + "execute_kwargs", {} + ), ) for tool_call in parsed_tool_calls ] ) - _req.add_tool_response_messages(self.processing_class, [resp for resp, _, _ in tool_call_results]) - for tool_call, (resp, reward, metrics) in zip(parsed_tool_calls, tool_call_results, strict=True): + _req.add_tool_response_messages( + self.processing_class, + [resp for resp, _, _ in tool_call_results], + ) + for tool_call, (resp, reward, metrics) in zip( + parsed_tool_calls, tool_call_results, strict=True + ): _req.update_metrics(metrics, tool_call.function.name) if len(_req.input_ids) >= self.config.max_model_len: finish_reason_type = FinishReasonTypeEnum.STOP break _req.state = AsyncRolloutRequestStateEnum.RUNNING else: - raise ValueError(f"Unexpected tool calling last message state: 
{_req.messages[-1]}") + raise ValueError( + f"Unexpected tool calling last message state: {_req.messages[-1]}" + ) elif _req.state == AsyncRolloutRequestStateEnum.RUNNING: # Only continue the conversation if the prompt length is not greater than max_model_len - 1, # since SGLang raises an error when max_new_tokens + 1 is greater to max_model_len (the extra # token accounts for the EOS token). - if len(_req.get_generation_prompt_ids(self.processing_class)) + 1 >= self.config.max_model_len: + if ( + len(_req.get_generation_prompt_ids(self.processing_class)) + 1 + >= self.config.max_model_len + ): finish_reason_type = FinishReasonTypeEnum.LENGTH break @@ -881,22 +964,32 @@ async def _async_rollout_a_request( ) if video_data: logger.warning( - "video support is not implemented yet, current length of video data is %d", len(video_data) + "video support is not implemented yet, current length of video data is %d", + len(video_data), ) - output = await self._handle_engine_call(_req, request_sampling_params, image_data=image_data) + output = await self._handle_engine_call( + _req, request_sampling_params, image_data=image_data + ) content = output["text"] - finish_reason_type = FinishReasonTypeEnum.from_str(output["meta_info"]["finish_reason"]["type"]) + finish_reason_type = FinishReasonTypeEnum.from_str( + output["meta_info"]["finish_reason"]["type"] + ) current_turns += 1 if finish_reason_type == FinishReasonTypeEnum.LENGTH: _req.add_assistant_message(self.processing_class, content) break else: - if self._function_call_parser and self._function_call_parser.has_tool_call(content): + if ( + self._function_call_parser + and self._function_call_parser.has_tool_call(content) + ): finish_reason_type = FinishReasonTypeEnum.TOOL_CALL _req.state = AsyncRolloutRequestStateEnum.TOOL_CALLING try: - normed_content, tool_calls = self._function_call_parser.parse_non_stream(content) + normed_content, tool_calls = ( + self._function_call_parser.parse_non_stream(content) + ) except 
JSONDecodeError: normed_content = content tool_calls = [] @@ -905,10 +998,12 @@ async def _async_rollout_a_request( tool_calls = [] parsed_tool_calls = [] for tool_call in tool_calls: - function, has_decode_error = OpenAIFunctionCallSchema.from_openai_function_parsed_schema( - OpenAIFunctionParsedSchema( - name=tool_call.name, - arguments=tool_call.parameters, + function, has_decode_error = ( + OpenAIFunctionCallSchema.from_openai_function_parsed_schema( + OpenAIFunctionParsedSchema( + name=tool_call.name, + arguments=tool_call.parameters, + ) ) ) # Drop the tool call if its arguments has decode error @@ -922,7 +1017,9 @@ async def _async_rollout_a_request( ) if len(parsed_tool_calls) > 0: _req.add_assistant_message( - self.processing_class, normed_content, tool_calls=parsed_tool_calls + self.processing_class, + normed_content, + tool_calls=parsed_tool_calls, ) else: _req.add_assistant_message(self.processing_class, content) @@ -938,14 +1035,17 @@ async def _async_rollout_a_request( _req.interaction_kwargs and self.interaction_map and user_turns < self.config.multi_turn.max_user_turns - and current_turns < self.config.multi_turn.max_assistant_turns + and current_turns + < self.config.multi_turn.max_assistant_turns ): _req.state = AsyncRolloutRequestStateEnum.INTERACTING else: break elif _req.state == AsyncRolloutRequestStateEnum.INTERACTING: user_turns += 1 - messages = [{"role": x.role, "content": x.content} for x in _req.messages] + messages = [ + {"role": x.role, "content": x.content} for x in _req.messages + ] # Get interaction by name from interaction_kwargs interaction_name = _req.interaction_kwargs.get( @@ -958,8 +1058,10 @@ async def _async_rollout_a_request( ) interaction = self.interaction_map[interaction_name] - should_terminate_sequence, content, reward, metrics = await interaction.generate_response( - _req.request_id, messages, **_req.interaction_kwargs + should_terminate_sequence, content, reward, metrics = ( + await interaction.generate_response( + 
_req.request_id, messages, **_req.interaction_kwargs + ) ) user_turn_rewards.append(reward) if should_terminate_sequence: @@ -979,8 +1081,12 @@ async def _async_rollout_a_request( # Calculate the reward for each tool async def calc_reward_and_release_fn(name: str, tool: BaseTool): - reward = await tool.calc_reward(_req.request_id, **_req.tools_kwargs[name].get("calc_reward_kwargs", {})) - await tool.release(_req.request_id, **_req.tools_kwargs[name].get("release_kwargs", {})) + reward = await tool.calc_reward( + _req.request_id, **_req.tools_kwargs[name].get("calc_reward_kwargs", {}) + ) + await tool.release( + _req.request_id, **_req.tools_kwargs[name].get("release_kwargs", {}) + ) return name, reward tool_reward_tasks = [] @@ -995,15 +1101,26 @@ async def calc_reward_and_release_fn(name: str, tool: BaseTool): return _req async def _handle_engine_call( - self, _req: AsyncRolloutRequest, sampling_params: dict, image_data: Optional[list[Any]] = None + self, + _req: AsyncRolloutRequest, + sampling_params: dict, + image_data: Optional[list[Any]] = None, ) -> dict: generation_prompt_ids = _req.get_generation_prompt_ids(self.processing_class) - return await self._handle_engine_generate(generation_prompt_ids, sampling_params, image_data) + return await self._handle_engine_generate( + generation_prompt_ids, sampling_params, image_data + ) async def _handle_engine_generate( - self, generation_prompt_ids: list[int], sampling_params: dict, image_data: Optional[list[Any]] = None + self, + generation_prompt_ids: list[int], + sampling_params: dict, + image_data: Optional[list[Any]] = None, ) -> dict: - max_new_tokens = min(self.config.response_length, self.config.max_model_len - len(generation_prompt_ids) - 1) + max_new_tokens = min( + self.config.response_length, + self.config.max_model_len - len(generation_prompt_ids) - 1, + ) kwargs = sampling_params.copy() kwargs["max_new_tokens"] = max_new_tokens kwargs["n"] = 1 # group size is supported in preprocess @@ -1015,18 +1132,24 
@@ async def _handle_engine_generate( ) return output - async def _handle_pending_state(self, _req: AsyncRolloutRequest) -> AsyncRolloutRequest: + async def _handle_pending_state( + self, _req: AsyncRolloutRequest + ) -> AsyncRolloutRequest: if _req.tool_schemas is not None: tool_creation_coroutines = [] for tool_schema in _req.tool_schemas: tool = self._tool_map[tool_schema.function.name] create_kwargs = _req.tools_kwargs[tool.name].get("create_kwargs", {}) - tool_creation_coroutines.append(tool.create(_req.request_id, **create_kwargs)) + tool_creation_coroutines.append( + tool.create(_req.request_id, **create_kwargs) + ) await asyncio.gather(*tool_creation_coroutines) if _req.interaction_kwargs and self.interaction_map: interaction_kwargs = _req.interaction_kwargs # Get interaction by name from interaction_kwargs - interaction_name = interaction_kwargs.get("name", "gsm8k") # Default to gsm8k for backward compatibility + interaction_name = interaction_kwargs.get( + "name", "gsm8k" + ) # Default to gsm8k for backward compatibility if interaction_name not in self.interaction_map: raise ValueError( f"Interaction '{interaction_name}' not found in interaction_map. 
Available interactions: " @@ -1066,10 +1189,17 @@ def _req_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataPro loop = asyncio.get_event_loop() output_req_list = loop.run_until_complete( asyncio.gather( - *[self._async_rollout_a_request(req, do_sample, is_validate, **kwargs) for req in req_list], + *[ + self._async_rollout_a_request( + req, do_sample, is_validate, **kwargs + ) + for req in req_list + ], ) ) - sorted_output_req_list = sorted(output_req_list, key=lambda x: (x.batch_data_id, x.rollout_offset)) + sorted_output_req_list = sorted( + output_req_list, key=lambda x: (x.batch_data_id, x.rollout_offset) + ) else: sorted_output_req_list = None @@ -1091,7 +1221,9 @@ def _req_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataPro multi_modal_inputs = [] for req in sorted_output_req_list: - assert req.state == AsyncRolloutRequestStateEnum.COMPLETED, f"Request {req.request_id} is not completed" + assert ( + req.state == AsyncRolloutRequestStateEnum.COMPLETED + ), f"Request {req.request_id} is not completed" assert ( req.input_ids.shape[-1] == req.attention_mask.shape[-1] @@ -1119,10 +1251,18 @@ def _req_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataPro f"""{req.request_id=} has response_ids length {req.response_ids.shape[-1]} greater than max_response_len {self.config.response_length},\n{req=}""" ) - prompt_attention_mask.append(req.prompt_attention_mask.to(tgt_device).squeeze(0)) - response_attention_mask.append(req.response_attention_mask.to(tgt_device).squeeze(0)) - prompt_position_ids.append(req.prompt_position_ids.to(tgt_device).squeeze(0)) - response_position_ids.append(req.response_position_ids.to(tgt_device).squeeze(0)) + prompt_attention_mask.append( + req.prompt_attention_mask.to(tgt_device).squeeze(0) + ) + response_attention_mask.append( + req.response_attention_mask.to(tgt_device).squeeze(0) + ) + prompt_position_ids.append( + req.prompt_position_ids.to(tgt_device).squeeze(0) + ) + 
response_position_ids.append( + req.response_position_ids.to(tgt_device).squeeze(0) + ) prompt_loss_mask.append(req.prompt_loss_mask.to(tgt_device).squeeze(0)) response_loss_mask.append(req.response_loss_mask.to(tgt_device).squeeze(0)) messages.append({"messages": req.messages}) @@ -1136,10 +1276,16 @@ def _req_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataPro padding_side="left", ) if prompt_ids.shape[-1] < self.config.prompt_length: - prompt_ids = pad_sequence_to_length(prompt_ids, self.config.prompt_length, self.pad_token_id, left_pad=True) - response_ids = pad_sequence(response_ids, batch_first=True, padding_value=self.pad_token_id) + prompt_ids = pad_sequence_to_length( + prompt_ids, self.config.prompt_length, self.pad_token_id, left_pad=True + ) + response_ids = pad_sequence( + response_ids, batch_first=True, padding_value=self.pad_token_id + ) if response_ids.shape[-1] < self.config.response_length: - response_ids = pad_sequence_to_length(response_ids, self.config.response_length, self.pad_token_id) + response_ids = pad_sequence_to_length( + response_ids, self.config.response_length, self.pad_token_id + ) prompt_attention_mask = pad_sequence( prompt_attention_mask, batch_first=True, @@ -1150,22 +1296,34 @@ def _req_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataPro prompt_attention_mask = pad_sequence_to_length( prompt_attention_mask, self.config.prompt_length, 0, left_pad=True ) - response_attention_mask = pad_sequence(response_attention_mask, batch_first=True, padding_value=0) + response_attention_mask = pad_sequence( + response_attention_mask, batch_first=True, padding_value=0 + ) if response_attention_mask.shape[-1] < self.config.response_length: - response_attention_mask = pad_sequence_to_length(response_attention_mask, self.config.response_length, 0) + response_attention_mask = pad_sequence_to_length( + response_attention_mask, self.config.response_length, 0 + ) # padding prompt_position_ids if 
prompt_position_ids[0].dim() == 2: # if prompt_position_ids is a 2D tensor # e.g. from qwen2vl, prompt_position_ids.shape = (3, seq_len) - transposed_prompt_position_ids = [p.transpose(0, 1) for p in prompt_position_ids] + transposed_prompt_position_ids = [ + p.transpose(0, 1) for p in prompt_position_ids + ] prompt_position_ids = pad_sequence( - transposed_prompt_position_ids, batch_first=True, padding_value=0, padding_side="left" + transposed_prompt_position_ids, + batch_first=True, + padding_value=0, + padding_side="left", ) prompt_position_ids = prompt_position_ids.transpose(1, 2) else: prompt_position_ids = pad_sequence( - prompt_position_ids, batch_first=True, padding_value=0, padding_side="left" + prompt_position_ids, + batch_first=True, + padding_value=0, + padding_side="left", ) if prompt_position_ids.shape[-1] < self.config.prompt_length: prompt_position_ids = pad_sequence_to_length( @@ -1176,25 +1334,44 @@ def _req_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataPro if response_position_ids[0].dim() == 2: # if response_position_ids is a 2D tensor # e.g. 
from qwen2vl, response_position_ids.shape = (3, seq_len) - transposed_response_position_ids = [p.transpose(0, 1) for p in response_position_ids] + transposed_response_position_ids = [ + p.transpose(0, 1) for p in response_position_ids + ] response_position_ids = pad_sequence( - transposed_response_position_ids, batch_first=True, padding_value=0, padding_side="left" + transposed_response_position_ids, + batch_first=True, + padding_value=0, + padding_side="left", ) response_position_ids = response_position_ids.transpose(1, 2) else: - response_position_ids = pad_sequence(response_position_ids, batch_first=True, padding_value=0) + response_position_ids = pad_sequence( + response_position_ids, batch_first=True, padding_value=0 + ) if response_position_ids.shape[-1] < self.config.response_length: - response_position_ids = pad_sequence_to_length(response_position_ids, self.config.response_length, 0) + response_position_ids = pad_sequence_to_length( + response_position_ids, self.config.response_length, 0 + ) - prompt_loss_mask = pad_sequence(prompt_loss_mask, batch_first=True, padding_value=0, padding_side="left") + prompt_loss_mask = pad_sequence( + prompt_loss_mask, batch_first=True, padding_value=0, padding_side="left" + ) if prompt_loss_mask.shape[1] < self.config.prompt_length: - prompt_loss_mask = pad_sequence_to_length(prompt_loss_mask, self.config.prompt_length, 0, left_pad=True) - response_loss_mask = pad_sequence(response_loss_mask, batch_first=True, padding_value=0) + prompt_loss_mask = pad_sequence_to_length( + prompt_loss_mask, self.config.prompt_length, 0, left_pad=True + ) + response_loss_mask = pad_sequence( + response_loss_mask, batch_first=True, padding_value=0 + ) if response_loss_mask.shape[1] < self.config.response_length: - response_loss_mask = pad_sequence_to_length(response_loss_mask, self.config.response_length, 0) + response_loss_mask = pad_sequence_to_length( + response_loss_mask, self.config.response_length, 0 + ) input_ids = 
torch.cat((prompt_ids, response_ids), dim=-1) - attention_mask = torch.cat((prompt_attention_mask, response_attention_mask), dim=-1) + attention_mask = torch.cat( + (prompt_attention_mask, response_attention_mask), dim=-1 + ) position_ids = torch.cat((prompt_position_ids, response_position_ids), dim=-1) # Construct the batch data @@ -1224,10 +1401,12 @@ def _req_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataPro }, ) - def _preprocess_prompt_to_async_rollout_requests(self, prompts: DataProto, n: int = 1) -> list[AsyncRolloutRequest]: - assert "raw_prompt" in prompts.non_tensor_batch, ( - "need data.return_raw_chat=True, due to no official way do parse_messages" - ) + def _preprocess_prompt_to_async_rollout_requests( + self, prompts: DataProto, n: int = 1 + ) -> list[AsyncRolloutRequest]: + assert ( + "raw_prompt" in prompts.non_tensor_batch + ), "need data.return_raw_chat=True, due to no official way do parse_messages" logger.info( "n is deprecated for SGLang rollout since ray ppo trainer will repeat the prompts for rollout.n times" ) @@ -1237,21 +1416,34 @@ def _preprocess_prompt_to_async_rollout_requests(self, prompts: DataProto, n: in ) for data_idx, (raw_prompt, multi_modal_data) in enumerate( - zip(prompts.non_tensor_batch["raw_prompt"], multi_modal_data_list, strict=True) + zip( + prompts.non_tensor_batch["raw_prompt"], + multi_modal_data_list, + strict=True, + ) ): if self._tool_schemas: _tools_kwargs = prompts.non_tensor_batch["tools_kwargs"][data_idx] - _tool_schemas = [self._tool_map[k].get_openai_tool_schema() for k in _tools_kwargs.keys()] + _tool_schemas = [ + self._tool_map[k].get_openai_tool_schema() + for k in _tools_kwargs.keys() + ] _input_ids = None _attention_mask = None else: - _input_ids = _pre_process_inputs(self.pad_token_id, prompts.batch["input_ids"][data_idx]) - _attention_mask = _pre_process_inputs(0, prompts.batch["attention_mask"][data_idx]) + _input_ids = _pre_process_inputs( + self.pad_token_id, 
prompts.batch["input_ids"][data_idx] + ) + _attention_mask = _pre_process_inputs( + 0, prompts.batch["attention_mask"][data_idx] + ) _tools_kwargs = {} _tool_schemas = None if self.interaction_map: - _interaction_kwargs = prompts.non_tensor_batch["interaction_kwargs"][data_idx] + _interaction_kwargs = prompts.non_tensor_batch["interaction_kwargs"][ + data_idx + ] else: _interaction_kwargs = {} @@ -1274,7 +1466,10 @@ def _preprocess_prompt_to_async_rollout_requests(self, prompts: DataProto, n: in reward_scores={}, max_prompt_len=self.config.prompt_length, max_response_len=self.config.response_length, - max_model_len=min(self.config.max_model_len, self.config.prompt_length + self.config.response_length), + max_model_len=min( + self.config.max_model_len, + self.config.prompt_length + self.config.response_length, + ), use_inference_chat_template=self.config.multi_turn.use_inference_chat_template, tokenization_sanity_check_mode=self.config.multi_turn.tokenization_sanity_check_mode, processing_class=self.processing_class, @@ -1323,7 +1518,10 @@ async def chat_completion(self, json_request): reward_scores={}, max_prompt_len=self.config.prompt_length, max_response_len=self.config.response_length, - max_model_len=min(self.config.max_model_len, self.config.prompt_length + self.config.response_length), + max_model_len=min( + self.config.max_model_len, + self.config.prompt_length + self.config.response_length, + ), use_inference_chat_template=self.config.multi_turn.use_inference_chat_template, tokenization_sanity_check_mode=self.config.multi_turn.tokenization_sanity_check_mode, processing_class=self.processing_class, @@ -1332,9 +1530,13 @@ async def chat_completion(self, json_request): # json_request already contains sampling_params # Filter only valid SamplingParams arguments valid_sampling_params = {} - temp_sampling_params = SamplingParams() # Create temporary instance to check valid attributes + temp_sampling_params = ( + SamplingParams() + ) # Create temporary instance to 
check valid attributes for k, v in json_request.items(): - if k not in ["messages", "model", "tools"] and hasattr(temp_sampling_params, k): + if k not in ["messages", "model", "tools"] and hasattr( + temp_sampling_params, k + ): valid_sampling_params[k] = v output = await self._handle_engine_call(req, valid_sampling_params) # it can be Dict or AsyncIterator[Dict] diff --git a/Agent0/executor_train/verl/verl/workers/rollout/sglang_rollout/utils.py b/Agent0/executor_train/verl/verl/workers/rollout/sglang_rollout/utils.py index 776bd13..fbe3af6 100644 --- a/Agent0/executor_train/verl/verl/workers/rollout/sglang_rollout/utils.py +++ b/Agent0/executor_train/verl/verl/workers/rollout/sglang_rollout/utils.py @@ -46,7 +46,9 @@ def broadcast_pyobj( serialized_data = pickle.dumps(data) size = len(serialized_data) - tensor_data = torch.ByteTensor(np.frombuffer(serialized_data, dtype=np.uint8)).to(device) + tensor_data = torch.ByteTensor( + np.frombuffer(serialized_data, dtype=np.uint8) + ).to(device) tensor_size = torch.tensor([size], dtype=torch.long, device=device) dist.broadcast(tensor_size, src=src, group=dist_group) diff --git a/Agent0/executor_train/verl/verl/workers/rollout/tokenizer.py b/Agent0/executor_train/verl/verl/workers/rollout/tokenizer.py index 1e1212e..d1c8ebb 100644 --- a/Agent0/executor_train/verl/verl/workers/rollout/tokenizer.py +++ b/Agent0/executor_train/verl/verl/workers/rollout/tokenizer.py @@ -116,7 +116,9 @@ def decode( pass @abstractmethod - def convert_ids_to_tokens(self, ids: int | list[int], skip_special_tokens: bool = False) -> str | list[str]: + def convert_ids_to_tokens( + self, ids: int | list[int], skip_special_tokens: bool = False + ) -> str | list[str]: """ Converts a single index or a sequence of indices in a token or a sequence of tokens, using the vocabulary and added tokens. 
diff --git a/Agent0/executor_train/verl/verl/workers/rollout/vllm_rollout/__init__.py b/Agent0/executor_train/verl/verl/workers/rollout/vllm_rollout/__init__.py index dac55e0..88be41c 100644 --- a/Agent0/executor_train/verl/verl/workers/rollout/vllm_rollout/__init__.py +++ b/Agent0/executor_train/verl/verl/workers/rollout/vllm_rollout/__init__.py @@ -39,6 +39,8 @@ def get_version(pkg): if match: vllm_package_version = match.group(1) else: - raise ValueError(f"Warning: Could not parse version format: {vllm_package_version}") + raise ValueError( + f"Warning: Could not parse version format: {vllm_package_version}" + ) vllm_mode = "spmd" diff --git a/Agent0/executor_train/verl/verl/workers/rollout/vllm_rollout/vllm_async_server.py b/Agent0/executor_train/verl/verl/workers/rollout/vllm_rollout/vllm_async_server.py index 988dac4..67ec642 100644 --- a/Agent0/executor_train/verl/verl/workers/rollout/vllm_rollout/vllm_async_server.py +++ b/Agent0/executor_train/verl/verl/workers/rollout/vllm_rollout/vllm_async_server.py @@ -24,7 +24,11 @@ from vllm import SamplingParams from vllm.engine.arg_utils import AsyncEngineArgs from vllm.entrypoints.logger import RequestLogger -from vllm.entrypoints.openai.protocol import ChatCompletionRequest, ChatCompletionResponse, ErrorResponse +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + ChatCompletionResponse, + ErrorResponse, +) from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels from vllm.inputs import TokensPrompt @@ -40,21 +44,30 @@ def _get_model_runner_workers(vllm_config, init_ray: bool = True): - assert vllm_config.instance_id is not None, "instance_id must be set for external ray actors." + assert ( + vllm_config.instance_id is not None + ), "instance_id must be set for external ray actors." 
fields = vllm_config.instance_id.split(":") assert len(fields) == 4, ( f"instance_id: {vllm_config.instance_id} must be in the format of " f":::." ) - namespace, wg_prefix, vllm_dp_size, vllm_dp_rank = fields[0], fields[1], int(fields[2]), int(fields[3]) + namespace, wg_prefix, vllm_dp_size, vllm_dp_rank = ( + fields[0], + fields[1], + int(fields[2]), + int(fields[3]), + ) # Make sure subprocess in same namespace as parent actor. # actor name format: {name_prefix}WorkerDict_{pg_idx}:{local_rank} if init_ray: ray.init(namespace=namespace) actor_names = [ - actor_name for actor_name in ray.util.list_named_actors() if actor_name.startswith(f"{wg_prefix}WorkerDict") + actor_name + for actor_name in ray.util.list_named_actors() + if actor_name.startswith(f"{wg_prefix}WorkerDict") ] vllm_tp_size = vllm_config.parallel_config.tensor_parallel_size @@ -71,9 +84,15 @@ def get_pg_index_and_local_rank(actor_name) -> tuple[int, int]: # sort actor names by pg_index and local_rank actor_names = sorted(actor_names, key=get_pg_index_and_local_rank) - actor_names = actor_names[vllm_dp_rank * vllm_tp_size : (vllm_dp_rank + 1) * vllm_tp_size] - workers: list[WorkerWrapperBase] = [ray.get_actor(actor_name) for actor_name in actor_names] - print(f"instance_id: {vllm_config.instance_id} initializes with external actors: {actor_names}") + actor_names = actor_names[ + vllm_dp_rank * vllm_tp_size : (vllm_dp_rank + 1) * vllm_tp_size + ] + workers: list[WorkerWrapperBase] = [ + ray.get_actor(actor_name) for actor_name in actor_names + ] + print( + f"instance_id: {vllm_config.instance_id} initializes with external actors: {actor_names}" + ) return workers @@ -84,7 +103,9 @@ class ExternalRayDistributedExecutor(Executor): uses_ray: bool = False def _init_executor(self) -> None: - self.workers = _get_model_runner_workers(vllm_config=self.vllm_config, init_ray=True) + self.workers = _get_model_runner_workers( + vllm_config=self.vllm_config, init_ray=True + ) kwargs = dict( 
vllm_config=self.vllm_config, @@ -114,7 +135,10 @@ def collective_rpc( # ~3ms overhead per schedule step due to SchedulerOutput/ModelRunnerOutput serialization/deserialization. outputs = ray.get( - [worker.execute_method.remote(sent_method, *args, **(kwargs or {})) for worker in self.workers] + [ + worker.execute_method.remote(sent_method, *args, **(kwargs or {})) + for worker in self.workers + ] ) return outputs @@ -190,7 +214,9 @@ class AsyncvLLMServer(AsyncServerBase): For vLLM AsyncLLM design, see: https://github.com/vllm-project/vllm/pull/9826 """ - def __init__(self, config: DictConfig, vllm_dp_size: int, vllm_dp_rank: int, wg_prefix: str): + def __init__( + self, config: DictConfig, vllm_dp_size: int, vllm_dp_rank: int, wg_prefix: str + ): """ Args: config: DictConfig. @@ -217,7 +243,11 @@ async def init_engine(self): tensor_parallel_size = config.get("tensor_model_parallel_size", 1) max_num_batched_tokens = config.get("max_num_batched_tokens", 8192) - max_model_len = config.max_model_len if config.max_model_len else config.prompt_length + config.response_length + max_model_len = ( + config.max_model_len + if config.max_model_len + else config.prompt_length + config.response_length + ) self.max_model_len = int(max_model_len) # Override default generation config from hugging face model config, @@ -285,12 +315,19 @@ async def init_engine(self): def _create_engine_config(self, engine_args: AsyncEngineArgs): vllm_config = engine_args.create_engine_config() namespace = ray.get_runtime_context().namespace - vllm_config.instance_id = f"{namespace}:{self.wg_prefix}:{self.vllm_dp_size}:{self.vllm_dp_rank}" + vllm_config.instance_id = ( + f"{namespace}:{self.wg_prefix}:{self.vllm_dp_size}:{self.vllm_dp_rank}" + ) # VERL_VLLM_ZMQ_ADDRESSES - if engine_args.distributed_executor_backend == ExternalZeroMQDistributedExecutor: + if ( + engine_args.distributed_executor_backend + == ExternalZeroMQDistributedExecutor + ): workers = 
_get_model_runner_workers(vllm_config=vllm_config, init_ray=False) - zmq_addresses = ray.get([worker.get_zeromq_address.remote() for worker in workers]) + zmq_addresses = ray.get( + [worker.get_zeromq_address.remote() for worker in workers] + ) print(f"VERL_VLLM_ZMQ_ADDRESSES: {zmq_addresses}") os.environ["VERL_VLLM_ZMQ_ADDRESSES"] = ",".join(zmq_addresses) @@ -303,21 +340,29 @@ async def chat_completion(self, raw_request: Request): """ request_json = await raw_request.json() request = ChatCompletionRequest(**request_json) - generator = await self.openai_serving_chat.create_chat_completion(request, raw_request) + generator = await self.openai_serving_chat.create_chat_completion( + request, raw_request + ) if isinstance(generator, ErrorResponse): - return JSONResponse(content=generator.model_dump(), status_code=generator.code) + return JSONResponse( + content=generator.model_dump(), status_code=generator.code + ) if request.stream: return StreamingResponse(content=generator, media_type="text/event-stream") else: assert isinstance(generator, ChatCompletionResponse) return JSONResponse(content=generator.model_dump()) - async def generate(self, prompt_ids: list[int], sampling_params: dict[str, Any], request_id: str) -> list[int]: + async def generate( + self, prompt_ids: list[int], sampling_params: dict[str, Any], request_id: str + ) -> list[int]: max_tokens = self.max_model_len - len(prompt_ids) sampling_params = SamplingParams(max_tokens=max_tokens, **sampling_params) prompt = TokensPrompt(prompt_token_ids=prompt_ids) - generator = self.engine.generate(prompt=prompt, sampling_params=sampling_params, request_id=request_id) + generator = self.engine.generate( + prompt=prompt, sampling_params=sampling_params, request_id=request_id + ) # Get final response final_res: Optional[RequestOutput] = None diff --git a/Agent0/executor_train/verl/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py 
b/Agent0/executor_train/verl/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py index af637c1..275b770 100644 --- a/Agent0/executor_train/verl/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py +++ b/Agent0/executor_train/verl/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py @@ -69,13 +69,17 @@ def _pre_process_inputs(pad_token_id, prompt_token_ids: torch.Tensor) -> list[in # remove the left padding in the prompt token_id # pad_token_id = self.llm_engine.tokenizer.pad_token_id if self.llm_engine.tokenizer.pad_token_id # is not None else self.llm_engine.tokenizer.eos_token_id - non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0] + non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][ + 0 + ] token_ids = prompt_token_ids[non_pad_index:].tolist() return token_ids class vLLMRollout(BaseRollout): - def __init__(self, model_path: str, config: DictConfig, tokenizer, model_hf_config, **kwargs): + def __init__( + self, model_path: str, config: DictConfig, tokenizer, model_hf_config, **kwargs + ): """A vLLM rollout. It requires the module is supported by the vllm. 
Args: @@ -89,9 +93,9 @@ def __init__(self, model_path: str, config: DictConfig, tokenizer, model_hf_conf self.config = config tensor_parallel_size = self.config.get("tensor_model_parallel_size", 1) - assert tensor_parallel_size <= torch.distributed.get_world_size(), ( - "tensor parallel size should be less than or equal to the world size" - ) + assert ( + tensor_parallel_size <= torch.distributed.get_world_size() + ), "tensor parallel size should be less than or equal to the world size" max_num_batched_tokens = self.config.get("max_num_batched_tokens", 8192) if kwargs.get("train_tp") is not None: @@ -100,7 +104,9 @@ def __init__(self, model_path: str, config: DictConfig, tokenizer, model_hf_conf os.environ["CUDA_TIMER_STREAM_KAFKA_ENABLE"] = "0" os.environ["MEGATRON_IMPORT_TIMERS"] = "0" - vllm_ps.initialize_model_parallel(tensor_model_parallel_size=tensor_parallel_size) + vllm_ps.initialize_model_parallel( + tensor_model_parallel_size=tensor_parallel_size + ) rope_scaling_config = getattr(model_hf_config, "rope_scaling", None) if not rope_scaling_config: @@ -110,16 +116,20 @@ def __init__(self, model_path: str, config: DictConfig, tokenizer, model_hf_conf elif hasattr(model_hf_config, "llm_config") and hasattr( model_hf_config.llm_config, "max_position_embeddings" ): - max_position_embeddings = model_hf_config.llm_config.max_position_embeddings + max_position_embeddings = ( + model_hf_config.llm_config.max_position_embeddings + ) elif hasattr(model_hf_config, "text_config") and hasattr( model_hf_config.text_config, "max_position_embeddings" ): - max_position_embeddings = model_hf_config.text_config.max_position_embeddings + max_position_embeddings = ( + model_hf_config.text_config.max_position_embeddings + ) if max_position_embeddings is None: raise ValueError("max_position_embeddings not found in model_hf_config") - assert max_position_embeddings >= config.prompt_length + config.response_length, ( - "model context length should be greater than total sequence 
length" - ) + assert ( + max_position_embeddings >= config.prompt_length + config.response_length + ), "model context length should be greater than total sequence length" else: # handle type where there's a length extend factor # see https://qwen.readthedocs.io/en/latest/deployment/vllm.html#extended-context-support @@ -135,16 +145,23 @@ def __init__(self, model_path: str, config: DictConfig, tokenizer, model_hf_conf + f"max_position_embeddings={model_hf_config.max_position_embeddings}" ) - max_model_len = int(config.max_model_len or config.prompt_length + config.response_length) + max_model_len = int( + config.max_model_len or config.prompt_length + config.response_length + ) - if max_num_batched_tokens < max_model_len and self.config.enable_chunked_prefill: + if ( + max_num_batched_tokens < max_model_len + and self.config.enable_chunked_prefill + ): raise ValueError( "Enable chunked prefill, max_num_batched_tokens is smaller than max_model_len, \ please increase max_num_batched_tokens or disable chunked prefill" ) trust_remote_code = kwargs.get("trust_remote_code", False) - load_format = "dummy" if config.load_format.startswith("dummy") else config.load_format + load_format = ( + "dummy" if config.load_format.startswith("dummy") else config.load_format + ) lora_kwargs = kwargs.pop("lora_kwargs", {}) self.lora_kwargs = lora_kwargs @@ -158,7 +175,9 @@ def __init__(self, model_path: str, config: DictConfig, tokenizer, model_hf_conf # - `None` means not setting it, so we pop it, and leave it to vLLM default value # (which can vary across different vLLM versions); # - Otherwise it's the desired value we want to explicitly set. 
- engine_kwargs = {key: val for key, val in engine_kwargs.items() if val is not None} + engine_kwargs = { + key: val for key, val in engine_kwargs.items() if val is not None + } if config.get("limit_images", None): # support for multi-image data engine_kwargs["limit_mm_per_prompt"] = {"image": config.get("limit_images")} @@ -258,7 +277,11 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto: non_tensor_batch = prompts.non_tensor_batch if "raw_prompt_ids" not in non_tensor_batch: non_tensor_batch["raw_prompt_ids"] = np.array( - [_pre_process_inputs(self.pad_token_id, idx[i]) for i in range(batch_size)], dtype=object + [ + _pre_process_inputs(self.pad_token_id, idx[i]) + for i in range(batch_size) + ], + dtype=object, ) if batch_size != len(non_tensor_batch["raw_prompt_ids"]): @@ -267,12 +290,20 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto: if "multi_modal_data" in non_tensor_batch: vllm_inputs = [] for raw_prompt_ids, multi_modal_data in zip( - non_tensor_batch.pop("raw_prompt_ids"), non_tensor_batch.pop("multi_modal_data"), strict=True + non_tensor_batch.pop("raw_prompt_ids"), + non_tensor_batch.pop("multi_modal_data"), + strict=True, ): - vllm_inputs.append({"prompt_token_ids": raw_prompt_ids, "multi_modal_data": multi_modal_data}) + vllm_inputs.append( + { + "prompt_token_ids": raw_prompt_ids, + "multi_modal_data": multi_modal_data, + } + ) else: vllm_inputs = [ - {"prompt_token_ids": raw_prompt_ids} for raw_prompt_ids in non_tensor_batch.pop("raw_prompt_ids") + {"prompt_token_ids": raw_prompt_ids} + for raw_prompt_ids in non_tensor_batch.pop("raw_prompt_ids") ] # ensure the type of `prompt_token_ids` passed to vllm is list[int] @@ -311,7 +342,11 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto: if len(lora_int_ids) > 0: lora_int_id = lora_int_ids[0] lora_requests = [ - LoRARequest(lora_name=f"{lora_int_id}", lora_int_id=lora_int_id, lora_path="/simon-stub-path") + LoRARequest( + 
lora_name=f"{lora_int_id}", + lora_int_id=lora_int_id, + lora_path="/simon-stub-path", + ) ] * batch_size # users can customize different sampling_params at different run @@ -338,9 +373,9 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto: curr_log_prob.append(logprob[response_ids[i]].logprob) rollout_log_probs.append(curr_log_prob) - response = pad_2d_list_to_length(response, self.pad_token_id, max_length=self.config.response_length).to( - idx.device - ) + response = pad_2d_list_to_length( + response, self.pad_token_id, max_length=self.config.response_length + ).to(idx.device) if self.config.calculate_log_probs: rollout_log_probs = pad_2d_list_to_length( rollout_log_probs, -1, max_length=self.config.response_length @@ -350,10 +385,14 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto: seq = torch.cat([idx, response], dim=-1) response_length = response.size(1) - delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device) + delta_position_id = torch.arange( + 1, response_length + 1, device=position_ids.device + ) delta_position_id = delta_position_id.unsqueeze(0).expand(batch_size, -1) if position_ids.dim() == 3: # qwen2vl mrope - delta_position_id = delta_position_id.view(batch_size, 1, -1).expand(batch_size, 3, -1) + delta_position_id = delta_position_id.view(batch_size, 1, -1).expand( + batch_size, 3, -1 + ) # TODO(sgm): fix position_ids on right_pad # prompt: left pad + response: right pad @@ -405,7 +444,9 @@ class vLLMAsyncRollout: which is engine in single worker process. 
""" - def __init__(self, model_path: str, config: DictConfig, tokenizer, model_hf_config, **kwargs): + def __init__( + self, model_path: str, config: DictConfig, tokenizer, model_hf_config, **kwargs + ): self.tokenizer = tokenizer # Engine is deferred to be initialized in init_worker @@ -472,7 +513,9 @@ def load_model(self, *args, **kwargs): self.sharding_manager.inference_engine = self.inference_engine self.sharding_manager.model_runner = self.inference_engine.worker.model_runner - _monkey_patch_compute_logits(self.inference_engine.worker.model_runner.model, len(self.tokenizer)) + _monkey_patch_compute_logits( + self.inference_engine.worker.model_runner.model, len(self.tokenizer) + ) def sleep(self, *args, **kwargs): """Offload model weights and discard kv cache.""" diff --git a/Agent0/executor_train/verl/verl/workers/sharding_manager/fsdp_sglang.py b/Agent0/executor_train/verl/verl/workers/sharding_manager/fsdp_sglang.py index be74bbd..77bc3ac 100644 --- a/Agent0/executor_train/verl/verl/workers/sharding_manager/fsdp_sglang.py +++ b/Agent0/executor_train/verl/verl/workers/sharding_manager/fsdp_sglang.py @@ -24,14 +24,24 @@ from sglang.srt.model_executor.model_runner import LocalSerializedTensor from sglang.srt.utils import MultiprocessingSerializer from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.fsdp.api import FullStateDictConfig, ShardedStateDictConfig, StateDictType -from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp.api import ( + FullStateDictConfig, + ShardedStateDictConfig, + StateDictType, +) +from torch.distributed.fsdp.fully_sharded_data_parallel import ( + FullyShardedDataParallel as FSDP, +) from torch.distributed.tensor import DTensor from verl import DataProto from verl.protocol import all_gather_data_proto from verl.utils.device import get_device_id, get_torch_device -from verl.utils.fsdp_utils import fsdp_version, load_fsdp_model_to_gpu, 
offload_fsdp_model_to_cpu +from verl.utils.fsdp_utils import ( + fsdp_version, + load_fsdp_model_to_gpu, + offload_fsdp_model_to_cpu, +) from verl.utils.model import convert_weight_keys from verl.utils.profiler import GPUMemoryLogger, log_gpu_memory_usage, simple_timer from verl.utils.torch_functional import check_device_is_available @@ -74,7 +84,9 @@ def __init__( self.full_params = full_params if full_params and fsdp_version(self.module) == 1: FSDP.set_state_dict_type( - self.module, state_dict_type=StateDictType.FULL_STATE_DICT, state_dict_config=FullStateDictConfig() + self.module, + state_dict_type=StateDictType.FULL_STATE_DICT, + state_dict_config=FullStateDictConfig(), ) elif fsdp_version(self.module) == 1: FSDP.set_state_dict_type( @@ -91,7 +103,9 @@ def __init__( # get a random rng states if self.device_mesh is not None: gen_dp_rank = self.device_mesh["dp"].get_local_rank() - get_torch_device().manual_seed(gen_dp_rank + 1000) # make sure all tp ranks have the same random states + get_torch_device().manual_seed( + gen_dp_rank + 1000 + ) # make sure all tp ranks have the same random states self.gen_random_states = get_torch_device().get_rng_state() get_torch_device().set_rng_state(self.torch_random_states) else: @@ -114,10 +128,14 @@ async def update_weights(self, params): named_tensors = [(k, v) for k, v in params.items()] load_format = None for tensor_index, (name, tensor) in enumerate(named_tensors): - serialized_tensor = MultiprocessingSerializer.serialize(_preprocess_tensor_for_update_weights(tensor)) + serialized_tensor = MultiprocessingSerializer.serialize( + _preprocess_tensor_for_update_weights(tensor) + ) if self.device_mesh["infer_tp"].get_local_rank() == 0: - gathered_serialized_tensors = [None for _ in range(self.device_mesh["infer_tp"].mesh.size()[0])] + gathered_serialized_tensors = [ + None for _ in range(self.device_mesh["infer_tp"].mesh.size()[0]) + ] else: gathered_serialized_tensors = None dist.gather_object( @@ -140,43 +158,65 @@ async 
def update_weights(self, params): ) async def release_memory(self): - if self.device_mesh["infer_tp"].get_local_rank() == 0 and self.rollout_config.free_cache_engine: + if ( + self.device_mesh["infer_tp"].get_local_rank() == 0 + and self.rollout_config.free_cache_engine + ): await self.inference_engine.release_memory_occupation() @GPUMemoryLogger(role="FSDPSGLangShardingManager enter", logger=logger) async def wake_up(self): get_torch_device().empty_cache() - if self.device_mesh["infer_tp"].get_local_rank() == 0 and self.rollout_config.free_cache_engine: + if ( + self.device_mesh["infer_tp"].get_local_rank() == 0 + and self.rollout_config.free_cache_engine + ): if self.multi_stage_wake_up: await self.inference_engine.resume_memory_occupation(tags=["weights"]) - log_gpu_memory_usage("Before resume SGLang weights in sharding manager", logger=logger) + log_gpu_memory_usage( + "Before resume SGLang weights in sharding manager", logger=logger + ) else: await self.inference_engine.resume_memory_occupation() - log_gpu_memory_usage("Before resume SGLang weights + kv_cache in sharding manager", logger=logger) + log_gpu_memory_usage( + "Before resume SGLang weights + kv_cache in sharding manager", + logger=logger, + ) - log_gpu_memory_usage("Before state_dict() in sharding manager memory", logger=logger) + log_gpu_memory_usage( + "Before state_dict() in sharding manager memory", logger=logger + ) if self.offload_param: load_fsdp_model_to_gpu(self.module) params = self.module.state_dict() - log_gpu_memory_usage("After state_dict() in sharding manager memory", logger=logger) + log_gpu_memory_usage( + "After state_dict() in sharding manager memory", logger=logger + ) device = get_device_id() # used when fsdp2 set cpu_offload_policy params = { - k: v.to(device, non_blocking=True) if fsdp_version(self.module) == 2 else v for k, v in params.items() + k: v.to(device, non_blocking=True) if fsdp_version(self.module) == 2 else v + for k, v in params.items() } # convert weight keys to 
match the model config - params = convert_weight_keys(params, getattr(self.module, "_fsdp_wrapped_module", self.module)) + params = convert_weight_keys( + params, getattr(self.module, "_fsdp_wrapped_module", self.module) + ) # Copy, not share memory await self.update_weights(params) - log_gpu_memory_usage("After sync model weights in sharding manager", logger=logger) + log_gpu_memory_usage( + "After sync model weights in sharding manager", logger=logger + ) del params if self.offload_param: offload_fsdp_model_to_cpu(self.module) get_torch_device().empty_cache() - log_gpu_memory_usage("After del state_dict and empty_cache in sharding manager", logger=logger) + log_gpu_memory_usage( + "After del state_dict and empty_cache in sharding manager", logger=logger + ) if ( self.multi_stage_wake_up @@ -184,7 +224,9 @@ async def wake_up(self): and self.device_mesh["infer_tp"].get_local_rank() == 0 ): await self.inference_engine.resume_memory_occupation(tags=["kv_cache"]) - log_gpu_memory_usage("After resume SGLang kv_cache in sharding manager", logger=logger) + log_gpu_memory_usage( + "After resume SGLang kv_cache in sharding manager", logger=logger + ) # important: need to manually set the random states of each tp to be identical. 
if self.device_mesh is not None: @@ -194,9 +236,13 @@ async def wake_up(self): @GPUMemoryLogger(role="FSDPSGLangShardingManager exit", logger=logger) async def sleep(self): if self.rollout_config.free_cache_engine: - log_gpu_memory_usage("Before SGLang offload in sharding manager", logger=logger) + log_gpu_memory_usage( + "Before SGLang offload in sharding manager", logger=logger + ) await self.release_memory() - log_gpu_memory_usage("After SGLang offload in sharding manager", logger=logger) + log_gpu_memory_usage( + "After SGLang offload in sharding manager", logger=logger + ) self.module.train() diff --git a/Agent0/executor_train/verl/verl/workers/sharding_manager/fsdp_ulysses.py b/Agent0/executor_train/verl/verl/workers/sharding_manager/fsdp_ulysses.py index 39ccb77..f45804f 100644 --- a/Agent0/executor_train/verl/verl/workers/sharding_manager/fsdp_ulysses.py +++ b/Agent0/executor_train/verl/verl/workers/sharding_manager/fsdp_ulysses.py @@ -19,7 +19,10 @@ from verl import DataProto from verl.protocol import all_gather_data_proto -from verl.utils.ulysses import get_ulysses_sequence_parallel_group, set_ulysses_sequence_parallel_group +from verl.utils.ulysses import ( + get_ulysses_sequence_parallel_group, + set_ulysses_sequence_parallel_group, +) from .base import BaseShardingManager diff --git a/Agent0/executor_train/verl/verl/workers/sharding_manager/fsdp_vllm.py b/Agent0/executor_train/verl/verl/workers/sharding_manager/fsdp_vllm.py index 1a9677d..2cf3ee1 100644 --- a/Agent0/executor_train/verl/verl/workers/sharding_manager/fsdp_vllm.py +++ b/Agent0/executor_train/verl/verl/workers/sharding_manager/fsdp_vllm.py @@ -19,8 +19,14 @@ from collections import OrderedDict from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.fsdp.api import FullStateDictConfig, ShardedStateDictConfig, StateDictType -from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp.api import ( + 
FullStateDictConfig, + ShardedStateDictConfig, + StateDictType, +) +from torch.distributed.fsdp.fully_sharded_data_parallel import ( + FullyShardedDataParallel as FSDP, +) try: # for torch 2.5+ @@ -41,10 +47,19 @@ load_fsdp_model_to_gpu, offload_fsdp_model_to_cpu, ) -from verl.utils.model import check_exclude_modules, check_target_modules, convert_weight_keys +from verl.utils.model import ( + check_exclude_modules, + check_target_modules, + convert_weight_keys, +) from verl.utils.profiler import GPUMemoryLogger, log_gpu_memory_usage, simple_timer from verl.utils.torch_functional import check_device_is_available -from verl.utils.vllm_utils import TensorLoRARequest, VLLMHijack, is_version_ge, patch_vllm_moe_model_weight_loader +from verl.utils.vllm_utils import ( + TensorLoRARequest, + VLLMHijack, + is_version_ge, + patch_vllm_moe_model_weight_loader, +) from .base import BaseShardingManager @@ -96,7 +111,9 @@ def __init__( self.full_params = full_params if full_params and fsdp_version(self.module) == 1: FSDP.set_state_dict_type( - self.module, state_dict_type=StateDictType.FULL_STATE_DICT, state_dict_config=FullStateDictConfig() + self.module, + state_dict_type=StateDictType.FULL_STATE_DICT, + state_dict_config=FullStateDictConfig(), ) elif fsdp_version(self.module) == 1: FSDP.set_state_dict_type( @@ -113,7 +130,9 @@ def __init__( # get a random rng states if self.device_mesh is not None: gen_dp_rank = self.device_mesh["dp"].get_local_rank() - get_torch_device().manual_seed(gen_dp_rank + 1000) # make sure all tp ranks have the same random states + get_torch_device().manual_seed( + gen_dp_rank + 1000 + ) # make sure all tp ranks have the same random states self.gen_random_states = get_torch_device().get_rng_state() get_torch_device().set_rng_state(self.torch_random_states) else: @@ -147,19 +166,27 @@ def __collect_lora_params() -> OrderedDict: if self.base_sync_done: lora_params = get_peft_model_state_dict(peft_model) lora_params = { - name: 
param.full_tensor().detach().cpu() - if hasattr(param, "full_tensor") - else param.detach().cpu() + name: ( + param.full_tensor().detach().cpu() + if hasattr(param, "full_tensor") + else param.detach().cpu() + ) for name, param in lora_params.items() } else: model = peft_model.base_model.model - orig_dev = "cpu" if "cpu" in str(next(model.parameters()).device) else get_device_name() + orig_dev = ( + "cpu" + if "cpu" in str(next(model.parameters()).device) + else get_device_name() + ) model = model.to("cpu") for name, param in model.state_dict().items(): if any(x in name for x in ["_flat_param", "lora_"]): continue - name = name.replace("_fsdp_wrapped_module.", "").replace(".base_layer", "") + name = name.replace( + "_fsdp_wrapped_module.", "" + ).replace(".base_layer", "") lora_params[name] = ( param.full_tensor().detach().cpu() if hasattr(param, "full_tensor") @@ -172,12 +199,18 @@ def __collect_lora_params() -> OrderedDict: lora_params = get_peft_model_state_dict(peft_model) else: model = peft_model.base_model.model - orig_dev = "cpu" if "cpu" in str(next(model.parameters()).device) else get_device_name() + orig_dev = ( + "cpu" + if "cpu" in str(next(model.parameters()).device) + else get_device_name() + ) model = model.to("cpu") for name, param in model.state_dict().items(): if any(x in name for x in ["_flat_param", "lora_"]): continue - name = name.replace("_fsdp_wrapped_module.", "").replace(".base_layer", "") + name = name.replace("_fsdp_wrapped_module.", "").replace( + ".base_layer", "" + ) lora_params[name] = param.detach().cpu() model = model.to(orig_dev) return lora_params @@ -193,7 +226,9 @@ def __collect_lora_params() -> OrderedDict: with simple_timer("reshard", self.timing): get_torch_device().empty_cache() - log_gpu_memory_usage("Before state_dict() in sharding manager memory", logger=logger) + log_gpu_memory_usage( + "Before state_dict() in sharding manager memory", logger=logger + ) if self.offload_param: load_fsdp_model_to_gpu(self.module) @@ 
-204,18 +239,27 @@ def __collect_lora_params() -> OrderedDict: params = __collect_lora_params() else: params = self.module.state_dict() - params = convert_weight_keys(params, getattr(self.module, "_fsdp_wrapped_module", self.module)) - log_gpu_memory_usage("After state_dict() in sharding manager memory", logger=logger) + params = convert_weight_keys( + params, getattr(self.module, "_fsdp_wrapped_module", self.module) + ) + log_gpu_memory_usage( + "After state_dict() in sharding manager memory", logger=logger + ) if self.rollout_config.free_cache_engine: - if "tags" in inspect.signature(self.inference_engine.wake_up).parameters: + if ( + "tags" + in inspect.signature(self.inference_engine.wake_up).parameters + ): self.inference_engine.wake_up(tags=["weights"]) else: self.inference_engine.wake_up() # update model params self.update_params(params, peft_config=peft_config) - log_gpu_memory_usage("After sync model weights in sharding manager", logger=logger) + log_gpu_memory_usage( + "After sync model weights in sharding manager", logger=logger + ) del params if self.offload_param: offload_fsdp_model_to_cpu(self.module) @@ -223,11 +267,15 @@ def __collect_lora_params() -> OrderedDict: if ( self.rollout_config.free_cache_engine - and "tags" in inspect.signature(self.inference_engine.wake_up).parameters + and "tags" + in inspect.signature(self.inference_engine.wake_up).parameters ): self.inference_engine.wake_up(tags=["kv_cache"]) - log_gpu_memory_usage("After del state_dict and empty_cache in sharding manager", logger=logger) + log_gpu_memory_usage( + "After del state_dict and empty_cache in sharding manager", + logger=logger, + ) # important: need to manually set the random states of each tp to be identical. if self.device_mesh is not None: @@ -308,35 +356,54 @@ def replace_lora_wrapper(k): Returns: str: Transformed parameter key for base layer. 
""" - stacked_params = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"] + stacked_params = [ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", + ] if k.endswith(".weight"): module_k = k[: -len(".weight")] if check_exclude_modules(peft_config, module_k): return k - elif any([module_k.endswith(s) for s in stacked_params]) or check_target_modules( - peft_config, module_k - ): + elif any( + [module_k.endswith(s) for s in stacked_params] + ) or check_target_modules(peft_config, module_k): return f"{module_k}.base_layer.weight" if k.endswith(".bias"): module_k = k[: -len(".bias")] if check_exclude_modules(peft_config, module_k): return k - elif any([module_k.endswith(s) for s in stacked_params]) or check_target_modules( - peft_config, module_k - ): + elif any( + [module_k.endswith(s) for s in stacked_params] + ) or check_target_modules(peft_config, module_k): return f"{module_k}.base_layer.bias" return k - updated_params = {replace_lora_wrapper(k): v for k, v in updated_params.items()} + updated_params = { + replace_lora_wrapper(k): v for k, v in updated_params.items() + } patch_vllm_moe_model_weight_loader(model) device = get_device_id() # used when fsdp2 set cpu_offload_policy loaded_params = model.load_weights( ( - (name, param.to(device, non_blocking=True).full_tensor() if isinstance(param, DTensor) else param) + ( + name, + ( + param.to(device, non_blocking=True).full_tensor() + if isinstance(param, DTensor) + else param + ), + ) for name, param in updated_params.items() ) ) self.base_sync_done = True - logger.info(f"vLLM load weights, loaded_params: {len(loaded_params) if loaded_params else -1}") + logger.info( + f"vLLM load weights, loaded_params: {len(loaded_params) if loaded_params else -1}" + ) diff --git a/Agent0/executor_train/verl/verl/workers/sharding_manager/megatron_sglang.py b/Agent0/executor_train/verl/verl/workers/sharding_manager/megatron_sglang.py index 9bcc1f0..415e987 100644 --- 
a/Agent0/executor_train/verl/verl/workers/sharding_manager/megatron_sglang.py +++ b/Agent0/executor_train/verl/verl/workers/sharding_manager/megatron_sglang.py @@ -111,7 +111,9 @@ def __init__( # get a random rng states if self.device_mesh is not None: gen_dp_rank = self.device_mesh["dp"].get_local_rank() - get_torch_device().manual_seed(gen_dp_rank + 1000) # make sure all tp ranks have the same random states + get_torch_device().manual_seed( + gen_dp_rank + 1000 + ) # make sure all tp ranks have the same random states self.gen_random_states = get_torch_device().get_rng_state() get_torch_device().set_rng_state(self.torch_random_states) else: @@ -130,7 +132,10 @@ def __exit__(self, exc_type, exc_value, traceback): loop.run_until_complete(self.sleep()) async def update_weights(self, params): - if self.device_mesh["tp"].get_local_rank() == 0 and self.rollout_config.free_cache_engine: + if ( + self.device_mesh["tp"].get_local_rank() == 0 + and self.rollout_config.free_cache_engine + ): await self.inference_engine.resume_memory_occupation() named_tensors = params load_format = None @@ -138,7 +143,9 @@ async def update_weights(self, params): serialized_tensor = MultiprocessingSerializer.serialize(tensor.detach()) if self.device_mesh["tp"].get_local_rank() == 0: - gathered_serialized_tensors = [None for _ in range(self.device_mesh["tp"].mesh.size()[0])] + gathered_serialized_tensors = [ + None for _ in range(self.device_mesh["tp"].mesh.size()[0]) + ] else: gathered_serialized_tensors = None dist.gather_object( @@ -163,7 +170,10 @@ async def update_weights(self, params): await self.inference_engine.flush_cache() async def release_memory(self): - if self.device_mesh["tp"].get_local_rank() == 0 and self.rollout_config.free_cache_engine: + if ( + self.device_mesh["tp"].get_local_rank() == 0 + and self.rollout_config.free_cache_engine + ): await self.inference_engine.release_memory_occupation() @GPUMemoryLogger(role="MegatronSGLangShardingManager enter", logger=logger) @@ 
-192,9 +202,13 @@ async def wake_up(self): @GPUMemoryLogger(role="MegatronSGLangShardingManager exit", logger=logger) async def sleep(self): if self.rollout_config.free_cache_engine: - log_gpu_memory_usage("Before SGLang offload in sharding manager", logger=logger) + log_gpu_memory_usage( + "Before SGLang offload in sharding manager", logger=logger + ) await self.release_memory() - log_gpu_memory_usage("After SGLang offload in sharding manager", logger=logger) + log_gpu_memory_usage( + "After SGLang offload in sharding manager", logger=logger + ) for model in self.actor_module: model.train() @@ -219,4 +233,6 @@ def postprocess_data(self, data: DataProto) -> DataProto: # DP_COMPUTE_PROTO: all training ranks are dp, the same as fsdp if self.infer_tp_size == 1: return data - return data.chunk(chunks=self.infer_tp_size)[self.device_mesh["tp"].get_local_rank()] + return data.chunk(chunks=self.infer_tp_size)[ + self.device_mesh["tp"].get_local_rank() + ] diff --git a/Agent0/executor_train/verl/verl/workers/sharding_manager/megatron_vllm.py b/Agent0/executor_train/verl/verl/workers/sharding_manager/megatron_vllm.py index b04352c..13e62a8 100644 --- a/Agent0/executor_train/verl/verl/workers/sharding_manager/megatron_vllm.py +++ b/Agent0/executor_train/verl/verl/workers/sharding_manager/megatron_vllm.py @@ -31,7 +31,11 @@ from verl.third_party.vllm import LLM from verl.third_party.vllm import parallel_state as vllm_ps from verl.utils.device import get_torch_device -from verl.utils.megatron_utils import load_megatron_model_to_gpu, offload_megatron_model_to_cpu, per_tensor_generator +from verl.utils.megatron_utils import ( + load_megatron_model_to_gpu, + offload_megatron_model_to_cpu, + per_tensor_generator, +) from verl.utils.profiler import GPUMemoryLogger, log_gpu_memory_usage from verl.utils.profiler.performance import simple_timer from verl.utils.torch_functional import check_device_is_available @@ -133,7 +137,9 @@ def __init__( self.torch_random_states = 
get_torch_device().get_rng_state() if self.device_mesh is not None: gen_dp_rank = self.device_mesh["dp"].get_local_rank() - get_torch_device().manual_seed(gen_dp_rank + 1000) # make sure all tp ranks have the same random states + get_torch_device().manual_seed( + gen_dp_rank + 1000 + ) # make sure all tp ranks have the same random states self.gen_random_states = get_torch_device().get_rng_state() get_torch_device().set_rng_state(self.torch_random_states) else: @@ -145,12 +151,17 @@ def __enter__(self): with simple_timer("reshard", self.timing): get_torch_device().empty_cache() - log_gpu_memory_usage("Before state_dict() in sharding manager memory", logger=logger) + log_gpu_memory_usage( + "Before state_dict() in sharding manager memory", logger=logger + ) if self.offload_param: load_megatron_model_to_gpu(self.actor_module) if self.rollout_config.free_cache_engine: - if "tags" in inspect.signature(self.inference_engine.wake_up).parameters: + if ( + "tags" + in inspect.signature(self.inference_engine.wake_up).parameters + ): self.inference_engine.wake_up(tags=["weights"]) else: self.inference_engine.wake_up() @@ -176,7 +187,8 @@ def __enter__(self): if ( self.rollout_config.free_cache_engine - and "tags" in inspect.signature(self.inference_engine.wake_up).parameters + and "tags" + in inspect.signature(self.inference_engine.wake_up).parameters ): self.inference_engine.wake_up(tags=["kv_cache"]) diff --git a/Agent0/executor_train/verl_tool/llm_agent/__init__.py b/Agent0/executor_train/verl_tool/llm_agent/__init__.py index 2766530..740673c 100644 --- a/Agent0/executor_train/verl_tool/llm_agent/__init__.py +++ b/Agent0/executor_train/verl_tool/llm_agent/__init__.py @@ -1,2 +1,2 @@ from .config import AgentActorConfig -from .manager import AgentActorManager \ No newline at end of file +from .manager import AgentActorManager diff --git a/Agent0/executor_train/verl_tool/llm_agent/config.py b/Agent0/executor_train/verl_tool/llm_agent/config.py index 30c481d..edcf749 100644 
--- a/Agent0/executor_train/verl_tool/llm_agent/config.py +++ b/Agent0/executor_train/verl_tool/llm_agent/config.py @@ -1,35 +1,48 @@ from dataclasses import dataclass + @dataclass class AgentActorConfig: - enable_agent: bool=True - max_turns: int=0 - min_turns: int=0 - max_start_length: int=None - max_prompt_length: int=None - max_response_length: int=None - max_model_len: int=None # Maximum model length, used for async rollout to limit the input length. - max_obs_length: int=None - max_action_length: int=None + enable_agent: bool = True + max_turns: int = 0 + min_turns: int = 0 + max_start_length: int = None + max_prompt_length: int = None + max_response_length: int = None + max_model_len: int = ( + None # Maximum model length, used for async rollout to limit the input length. + ) + max_obs_length: int = None + max_action_length: int = None tool_server_url: str = None - n: int=1 - truncate_obs_side: str='left' - truncate_response_side: str='left' - rolling_with_prompt: bool=False - call_tool_first: bool=False - action_stop_tokens: list=None - additional_eos_token_ids: list=None - mask_observations: bool=True - force_finish_for_last_turn: bool=False - enable_mtrl: bool=False - mtrl_role: str="user" - mtrl_sep: str=None # "\n<|im_start|>system\n{obs}<|im_end|>\n<|im_start|>assistant\n" - assistant_role: str="assistant" - turn_end_token: str="<|im_end|>" - rollout_mode: str="async" # "sync" or "async" - mask_overlong_loss: bool=False # whether to mask the overlong trajectory to not train on it - max_concurrent_trajectories: int=256 # Maximum number of concurrent trajectories for async rollout. If None, no limit is applied. - enable_tqdm: bool=True # Whether to enable tqdm for async rollout. - over_sampling: bool=False # Whether to over-sample the trajectories in async rollout. - tool_call_time_out: int=None # Timeout for tool calls in async rollout. - tool_call_max_retries: int=5 # Maximum number of retries for tool calls in async rollout. 
\ No newline at end of file + n: int = 1 + truncate_obs_side: str = "left" + truncate_response_side: str = "left" + rolling_with_prompt: bool = False + call_tool_first: bool = False + action_stop_tokens: list = None + additional_eos_token_ids: list = None + mask_observations: bool = True + force_finish_for_last_turn: bool = False + enable_mtrl: bool = False + mtrl_role: str = "user" + mtrl_sep: str = ( + None # "\n<|im_start|>system\n{obs}<|im_end|>\n<|im_start|>assistant\n" + ) + assistant_role: str = "assistant" + turn_end_token: str = "<|im_end|>" + rollout_mode: str = "async" # "sync" or "async" + mask_overlong_loss: bool = ( + False # whether to mask the overlong trajectory to not train on it + ) + max_concurrent_trajectories: int = ( + 256 # Maximum number of concurrent trajectories for async rollout. If None, no limit is applied. + ) + enable_tqdm: bool = True # Whether to enable tqdm for async rollout. + over_sampling: bool = ( + False # Whether to over-sample the trajectories in async rollout. + ) + tool_call_time_out: int = None # Timeout for tool calls in async rollout. + tool_call_max_retries: int = ( + 5 # Maximum number of retries for tool calls in async rollout. + ) diff --git a/Agent0/executor_train/verl_tool/llm_agent/manager.py b/Agent0/executor_train/verl_tool/llm_agent/manager.py index 8aff31a..a696cf4 100644 --- a/Agent0/executor_train/verl_tool/llm_agent/manager.py +++ b/Agent0/executor_train/verl_tool/llm_agent/manager.py @@ -24,7 +24,13 @@ from .tensor_helper import TensorHelper, TensorConfig from PIL import Image from .utils import PerformanceTimer, nested_copy -from .vision_utils import encode_image, encode_image_url, encode_video_url, decode_image_url, decode_video_url +from .vision_utils import ( + encode_image, + encode_image_url, + encode_video_url, + decode_image_url, + decode_video_url, +) logger = logging.getLogger(__file__) @@ -32,9 +38,10 @@ # other C0 control characters except common whitespace). 
CONTROL_CHAR_RE = re.compile( # this matches U+0000 through U+001F, excluding tab(09), LF(0A), CR(0D) - r'[\x00-\x08\x0B\x0C\x0E-\x1F]' + r"[\x00-\x08\x0B\x0C\x0E-\x1F]" ) + def sanitize_request(obj: Any) -> Any: """ Recursively walk through obj and: @@ -46,13 +53,15 @@ def sanitize_request(obj: Any) -> Any: if isinstance(obj, np.ndarray): obj = obj.tolist() if isinstance(obj, dict): - return {sanitize_request(key): sanitize_request(val) for key, val in obj.items()} + return { + sanitize_request(key): sanitize_request(val) for key, val in obj.items() + } elif isinstance(obj, (list, tuple)): return type(obj)(sanitize_request(item) for item in obj) elif isinstance(obj, str): # strip NUL (\x00) and other C0 control chars - return CONTROL_CHAR_RE.sub('', obj) - elif isinstance(obj,Image.Image): + return CONTROL_CHAR_RE.sub("", obj) + elif isinstance(obj, Image.Image): return encode_image(obj) else: return obj @@ -74,43 +83,71 @@ def __init__( self.config = config # self.logger = logger self.is_validation = is_validation - self.eos_token_id = self.generation_config.eos_token_id \ - if self.generation_config is not None else self.tokenizer.eos_token_id - self.tensor_fn = TensorHelper(TensorConfig( - pad_token_id=self.tokenizer.pad_token_id, - max_prompt_length=config.max_prompt_length, - max_obs_length=config.max_obs_length, - max_start_length=config.max_start_length, - max_response_length=config.max_response_length, - )) + self.eos_token_id = ( + self.generation_config.eos_token_id + if self.generation_config is not None + else self.tokenizer.eos_token_id + ) + self.tensor_fn = TensorHelper( + TensorConfig( + pad_token_id=self.tokenizer.pad_token_id, + max_prompt_length=config.max_prompt_length, + max_obs_length=config.max_obs_length, + max_start_length=config.max_start_length, + max_response_length=config.max_response_length, + ) + ) if self.config.action_stop_tokens is not None: if os.path.exists(self.config.action_stop_tokens): - with 
open(self.config.action_stop_tokens, 'r') as f: - self.action_stop_tokens = [x for x in f.read().split(',') if x] + with open(self.config.action_stop_tokens, "r") as f: + self.action_stop_tokens = [x for x in f.read().split(",") if x] logger.info(f"Using action stop tokens: {self.action_stop_tokens}") else: - raise ValueError(f"action_stop_tokens file not found: {self.config.action_stop_tokens}") + raise ValueError( + f"action_stop_tokens file not found: {self.config.action_stop_tokens}" + ) else: self.action_stop_tokens = [] self.additional_eos_token_ids = self.config.additional_eos_token_ids if isinstance(self.additional_eos_token_ids, str): - self.additional_eos_token_ids = [int(x) for x in self.additional_eos_token_ids.split(',')] - elif isinstance(self.additional_eos_token_ids, list) or isinstance(self.additional_eos_token_ids, omegaconf.listconfig.ListConfig): - self.additional_eos_token_ids = [int(x) for x in self.additional_eos_token_ids] + self.additional_eos_token_ids = [ + int(x) for x in self.additional_eos_token_ids.split(",") + ] + elif isinstance(self.additional_eos_token_ids, list) or isinstance( + self.additional_eos_token_ids, omegaconf.listconfig.ListConfig + ): + self.additional_eos_token_ids = [ + int(x) for x in self.additional_eos_token_ids + ] elif self.additional_eos_token_ids is None: self.additional_eos_token_ids = [] if self.config.mtrl_sep is None: messages = [{"role": "system", "content": "{obs}"}] - self.config.mtrl_sep = "\n" + self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) - self.config.mtrl_sep = self.config.mtrl_sep.replace("system", self.config.mtrl_role) - self.max_action_length = self.config.max_action_length if self.config.max_action_length is not None else 0 - self.max_model_len = int(config.max_model_len or config.max_prompt_length + config.max_response_length) + self.config.mtrl_sep = "\n" + self.tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + 
) + self.config.mtrl_sep = self.config.mtrl_sep.replace( + "system", self.config.mtrl_role + ) + self.max_action_length = ( + self.config.max_action_length + if self.config.max_action_length is not None + else 0 + ) + self.max_model_len = int( + config.max_model_len + or config.max_prompt_length + config.max_response_length + ) self.tokenizer_lock = asyncio.Lock() # for multimodal processing if self.processor: self.mm_prefix, self.mm_postfix = self.processor.apply_chat_template( [{"role": "system", "content": [{"type": "text", "text": "|||"}]}], - tokenize=False, add_generation_prompt=False).split("|||") # this is used to create the correct multi-modal prompt + tokenize=False, + add_generation_prompt=False, + ).split( + "|||" + ) # this is used to create the correct multi-modal prompt else: self.mm_prefix = "" self.mm_postfix = "" @@ -120,28 +157,27 @@ def __init__( logger.setLevel(logging.WARNING) @classmethod - def from_rollout_config(cls, actor_rollout_wg, rollout_config, rollout_mode="async"): + def from_rollout_config( + cls, actor_rollout_wg, rollout_config, rollout_mode="async" + ): agent_config = AgentActorConfig() - for key in getattr(rollout_config, 'agent', {}).keys(): + for key in getattr(rollout_config, "agent", {}).keys(): if key in agent_config.__dict__.keys(): setattr(agent_config, key, rollout_config.agent[key]) - setattr(agent_config, 'n', rollout_config.rollout.n) - setattr(agent_config, 'max_model_len', rollout_config.rollout.max_model_len) + setattr(agent_config, "n", rollout_config.rollout.n) + setattr(agent_config, "max_model_len", rollout_config.rollout.max_model_len) model_path = rollout_config.model.path agent_config.rollout_mode = rollout_mode print(f"AgentAsyncActorRolloutRefWorker: {agent_config}") agent_actor_manager = cls(model_path, actor_rollout_wg, agent_config) return agent_actor_manager - + def _batch_tokenize(self, responses: List[str]) -> torch.Tensor: """Tokenize a batch of responses.""" return self.tokenizer( - responses, - 
add_special_tokens=False, - return_tensors='pt', - padding="longest" - )['input_ids'] - + responses, add_special_tokens=False, return_tensors="pt", padding="longest" + )["input_ids"] + def repeat_inputs_by_n(self, inputs: DataProto, n=None, force=False): """ this version verl do not repeat the input by n times, so we manually repeat the input by n times @@ -152,8 +188,10 @@ def repeat_inputs_by_n(self, inputs: DataProto, n=None, force=False): # we manually repeat the input by n times if needed since every trajectory is independent do_sample = inputs.meta_info.get("do_sample", True) - assert 'traj_ids' in inputs.non_tensor_batch, "traj_ids should be claimed univerally in the ray trainer" - ori_len = len(inputs.batch['input_ids']) + assert ( + "traj_ids" in inputs.non_tensor_batch + ), "traj_ids should be claimed univerally in the ray trainer" + ori_len = len(inputs.batch["input_ids"]) if not do_sample: n = 1 else: @@ -162,22 +200,29 @@ def repeat_inputs_by_n(self, inputs: DataProto, n=None, force=False): n = self.config.val_kwargs.n else: n = self.config.n - + inputs = inputs.repeat(n, interleave=True) # add "_{i}" for each trajectory to the traj_ids for i in range(ori_len): for j in range(n): - inputs.non_tensor_batch['traj_ids'][i*n+j] += f"_{j}" + inputs.non_tensor_batch["traj_ids"][i * n + j] += f"_{j}" # deepcopy to avoid reference bug for key in inputs.non_tensor_batch.keys(): - if key == 'traj_ids': + if key == "traj_ids": continue # # check if it's the same reference as the inputs.non_tensor_batch[key][i] - inputs.non_tensor_batch[key][i*n+j] = nested_copy(inputs.non_tensor_batch[key][i*n]) - inputs.meta_info['is_repeated_by_n'] = True + inputs.non_tensor_batch[key][i * n + j] = nested_copy( + inputs.non_tensor_batch[key][i * n] + ) + inputs.meta_info["is_repeated_by_n"] = True return inputs - async def _postprocess_responses(self, responses: Union[torch.Tensor, List[str]], action_step: int, rollout_messages: list) -> torch.Tensor: + async def 
_postprocess_responses( + self, + responses: Union[torch.Tensor, List[str]], + action_step: int, + rollout_messages: list, + ) -> torch.Tensor: """Process responses to stop at python operation or answer operation. Args: responses (Union[torch.Tensor, List[str]]): Responses from the model, either as a tensor or a list of strings. of length sum(active_mask), which <= batch_size @@ -195,8 +240,7 @@ async def _postprocess_responses(self, responses: Union[torch.Tensor, List[str]] async with self.tokenizer_lock: if isinstance(responses, torch.Tensor): responses_str = self.tokenizer.batch_decode( - responses, - skip_special_tokens=True + responses, skip_special_tokens=True ) else: responses_str = responses @@ -206,34 +250,46 @@ async def _postprocess_responses(self, responses: Union[torch.Tensor, List[str]] rollout_messages[i].update_rollout_messages( { "role": self.config.assistant_role, - "content": responses_str[i] + "content": responses_str[i], } ) - + for i in range(len(responses_str)): # check if the response contains action stop tokens has_action = False for j in range(len(self.action_stop_tokens)): if self.action_stop_tokens[j] in responses_str[i]: - responses_str[i] = responses_str[i].split(self.action_stop_tokens[j])[0] + self.action_stop_tokens[j] + responses_str[i] = ( + responses_str[i].split(self.action_stop_tokens[j])[0] + + self.action_stop_tokens[j] + ) has_action = True break - + # judge whether do action or not if action_step >= self.config.min_turns: # do action if there are action stop tokens in the response - do_action = has_action or (self.config.enable_mtrl and not self.action_stop_tokens) + do_action = has_action or ( + self.config.enable_mtrl and not self.action_stop_tokens + ) else: # always do action, decided by the server about whether an action stops do_action = True if self.action_stop_tokens and not has_action: # force add a action stop token for those responses that do not have action stop tokens - turn_end_token_idx = 
responses_str[i].rfind(self.config.turn_end_token) + turn_end_token_idx = responses_str[i].rfind( + self.config.turn_end_token + ) if turn_end_token_idx != -1: - responses_str[i] = responses_str[i][:turn_end_token_idx] + self.action_stop_tokens[0] + responses_str[i] = ( + responses_str[i][:turn_end_token_idx] + + self.action_stop_tokens[0] + ) else: - responses_str[i] = responses_str[i] + self.action_stop_tokens[0] - + responses_str[i] = ( + responses_str[i] + self.action_stop_tokens[0] + ) + # now if do action, responses_str[i] should end with a action stop token, if not do action, we use the original response if do_action: if self.config.enable_mtrl: @@ -241,13 +297,23 @@ async def _postprocess_responses(self, responses: Union[torch.Tensor, List[str]] responses_str[i] += self.config.turn_end_token else: # preserve eos token - responses_str[i] = self.tokenizer.decode(responses[i][:effective_lens[i]], skip_special_tokens=False) - do_actions.append(do_action) + responses_str[i] = self.tokenizer.decode( + responses[i][: effective_lens[i]], skip_special_tokens=False + ) + do_actions.append(do_action) responses = self._batch_tokenize(responses_str).to(torch.int64) return responses, responses_str, do_actions, rollout_messages - async def _process_next_obs(self, next_obs: List[str], dones: List[bool], valid_action: List[bool], finishs: List[bool], tool_interact_info: List[dict], rollings: DataProto) -> Tuple[torch.Tensor, List[dict]]: + async def _process_next_obs( + self, + next_obs: List[str], + dones: List[bool], + valid_action: List[bool], + finishs: List[bool], + tool_interact_info: List[dict], + rollings: DataProto, + ) -> Tuple[torch.Tensor, List[dict]]: """Process next observations from environment. Args: next_obs (List[str]): List of next observations, only the text part. @@ -260,70 +326,91 @@ async def _process_next_obs(self, next_obs: List[str], dones: List[bool], valid_ next_obs_ids (torch.Tensor): Tokenized next observations. 
rollings (DataProto): Updated rolling state with new observations. """ - has_multi_modal_data = "multi_modal_data" in rollings.non_tensor_batch and rollings.non_tensor_batch['multi_modal_data'] is not None + has_multi_modal_data = ( + "multi_modal_data" in rollings.non_tensor_batch + and rollings.non_tensor_batch["multi_modal_data"] is not None + ) mm_data_list = None async with self.tokenizer_lock: mtrl_sep = self.config.mtrl_sep next_obs = [obs if not done else "" for obs, done in zip(next_obs, dones)] - if self.config.truncate_obs_side == 'left': + if self.config.truncate_obs_side == "left": next_obs_ids = self.tokenizer( next_obs, - padding='longest', - return_tensors='pt', + padding="longest", + return_tensors="pt", add_special_tokens=False, # Prevents adding special tokens - padding_side='left', - )['input_ids'].to(torch.int64) + padding_side="left", + )["input_ids"].to(torch.int64) if next_obs_ids.shape[1] > self.config.max_obs_length: - logger.warning(f"[WARNING] OBSERVATION TOO LONG, CONSIDER CHANGING YOUR CONFIG, {next_obs_ids.shape[1]} & {self.config.max_obs_length}") - next_obs_ids = next_obs_ids[:, -self.config.max_obs_length:] - elif self.config.truncate_obs_side == 'right': + logger.warning( + f"[WARNING] OBSERVATION TOO LONG, CONSIDER CHANGING YOUR CONFIG, {next_obs_ids.shape[1]} & {self.config.max_obs_length}" + ) + next_obs_ids = next_obs_ids[:, -self.config.max_obs_length :] + elif self.config.truncate_obs_side == "right": next_obs_ids = self.tokenizer( next_obs, - padding='longest', - return_tensors='pt', + padding="longest", + return_tensors="pt", add_special_tokens=False, # Prevents adding special tokens - padding_side='right', - )['input_ids'].to(torch.int64) + padding_side="right", + )["input_ids"].to(torch.int64) if next_obs_ids.shape[1] > self.config.max_obs_length: - logger.warning(f"[WARNING] OBSERVATION TOO LONG, CONSIDER CHANGING YOUR CONFIG, {next_obs_ids.shape[1]} & {self.config.max_obs_length}") - next_obs_ids = next_obs_ids[:, 
:self.config.max_obs_length] + logger.warning( + f"[WARNING] OBSERVATION TOO LONG, CONSIDER CHANGING YOUR CONFIG, {next_obs_ids.shape[1]} & {self.config.max_obs_length}" + ) + next_obs_ids = next_obs_ids[:, : self.config.max_obs_length] else: - raise ValueError(f"Invalid truncate_obs_side: {self.config.truncate_obs_side}") + raise ValueError( + f"Invalid truncate_obs_side: {self.config.truncate_obs_side}" + ) next_obs = self.tokenizer.batch_decode( - next_obs_ids, - skip_special_tokens=True + next_obs_ids, skip_special_tokens=True ) if not has_multi_modal_data: - + if self.config.enable_mtrl: processed_next_obs = [] for i in range(len(next_obs)): if finishs[i] or dones[i]: # do action is false - assert next_obs[i] == "", f"next_obs should be empty when finishs is True, but got {next_obs[i]}" + assert ( + next_obs[i] == "" + ), f"next_obs should be empty when finishs is True, but got {next_obs[i]}" processed_next_obs.append("") elif valid_action[i]: processed_next_obs.append(mtrl_sep.format(obs=next_obs[i])) else: - processed_next_obs.append(mtrl_sep.format(obs="Your action is not valid, please check the format and try again." + next_obs[i])) + processed_next_obs.append( + mtrl_sep.format( + obs="Your action is not valid, please check the format and try again." 
+ + next_obs[i] + ) + ) next_obs = processed_next_obs next_obs_ids = self.tokenizer( next_obs, - padding='longest', - return_tensors='pt', + padding="longest", + return_tensors="pt", add_special_tokens=False, # Prevents adding special tokens - )['input_ids'].to(torch.int64) + )["input_ids"].to(torch.int64) # update rollout messages with next_obs if "rollout_messages" in rollings.non_tensor_batch: for i in range(len(next_obs)): if next_obs[i]: - rollings.non_tensor_batch['rollout_messages'][i].update_rollout_messages( + rollings.non_tensor_batch["rollout_messages"][ + i + ].update_rollout_messages( { - "role": self.config.mtrl_role if self.config.enable_mtrl else self.config.assistant_role, - "content": next_obs[i] + "role": ( + self.config.mtrl_role + if self.config.enable_mtrl + else self.config.assistant_role + ), + "content": next_obs[i], } ) else: @@ -331,39 +418,64 @@ async def _process_next_obs(self, next_obs: List[str], dones: List[bool], valid_ raw_prompts = [] import traceback - + for k, tool_interact_info_k in enumerate(tool_interact_info): try: multi_modal_data = {} - next_obs_image = tool_interact_info_k.get('image', []) + next_obs_image = tool_interact_info_k.get("image", []) if not isinstance(next_obs_image, list): next_obs_image = [next_obs_image] - next_obs_image = [decode_image_url(img) for img in next_obs_image] + next_obs_image = [ + decode_image_url(img) for img in next_obs_image + ] multi_modal_data["image"] = next_obs_image - - next_obs_video = tool_interact_info_k.get('video', []) + + next_obs_video = tool_interact_info_k.get("video", []) if not isinstance(next_obs_video, list): next_obs_video = [next_obs_video] - next_obs_video = [decode_video_url(video) for video in next_obs_video] - multi_modal_data["video"] = [video.numpy() for video in next_obs_video] + next_obs_video = [ + decode_video_url(video) for video in next_obs_video + ] + multi_modal_data["video"] = [ + video.numpy() for video in next_obs_video + ] # add additional and