# How to Add Custom Risk Management Rules to the AI Hedge Fund Agent Workflow

> Learn to add custom risk management rules to the virattt/ai-hedge-fund agent workflow. Enhance your trading strategy by extending the risk manager or adding new nodes.

- Repository: [Virat Singh/ai-hedge-fund](https://github.com/virattt/ai-hedge-fund)
- Tags: how-to-guide
- Published: 2026-03-09

---

**You can add custom risk management rules to the virattt/ai-hedge-fund workflow by either extending the existing `risk_management_agent` in [`src/agents/risk_manager.py`](https://github.com/virattt/ai-hedge-fund/blob/main/src/agents/risk_manager.py) or inserting an additional risk node into the state graph defined in [`src/main.py`](https://github.com/virattt/ai-hedge-fund/blob/main/src/main.py).**

The virattt/ai-hedge-fund system employs a **state-graph architecture** built on LangGraph. Each step of the investment workflow is a node that reads from and writes to a shared `AgentState`. To add custom risk management rules, you register your own risk agent in `create_workflow`, either in place of the built-in risk node or as an extra node after it. That agent then enforces position limits, VaR constraints, or sector caps before the portfolio manager makes its trading decisions.
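The following plain-Python sketch illustrates the pattern without LangGraph. Each function stands in for a node, and all of them read and write one shared dictionary shaped like the `data` / `analyst_signals` layout described below. The node functions, the flat 20% limit, and the `decisions` key are hypothetical simplifications, not code from the repository.

```python
from typing import Any, Callable

# Hypothetical, simplified stand-in for AgentState: a plain dict with the same
# nested "data" / "analyst_signals" layout described in this guide.
State = dict[str, Any]


def toy_analyst(state: State) -> State:
    # An analyst node writes a directional signal under its own key.
    state["data"]["analyst_signals"]["toy_analyst"] = {
        "AAPL": {"signal": "bullish", "confidence": 70.0}
    }
    return state


def toy_risk_manager(state: State) -> State:
    # A risk node writes per-ticker dollar limits under "risk_management_agent".
    cash = state["data"]["portfolio"]["cash"]
    state["data"]["analyst_signals"]["risk_management_agent"] = {
        "AAPL": {"remaining_position_limit": 0.2 * cash}
    }
    return state


def toy_portfolio_manager(state: State) -> State:
    # The decision node reads the limit the risk node left in the shared state.
    signals = state["data"]["analyst_signals"]
    limit = signals["risk_management_agent"]["AAPL"]["remaining_position_limit"]
    state["data"]["decisions"] = {"AAPL": {"max_dollars": limit}}
    return state


def run_pipeline(state: State, nodes: list[Callable[[State], State]]) -> State:
    # Run nodes in order. Adding a custom risk rule means placing another
    # function between the risk manager and the portfolio manager.
    for node in nodes:
        state = node(state)
    return state


initial: State = {"data": {"portfolio": {"cash": 100_000.0}, "analyst_signals": {}}}
final = run_pipeline(initial, [toy_analyst, toy_risk_manager, toy_portfolio_manager])
print(final["data"]["decisions"])  # {'AAPL': {'max_dollars': 20000.0}}
```

In the real workflow, LangGraph derives the execution order from the edges registered in `create_workflow`. The explicit list here only mimics a linear chain.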

## Understanding the Risk Management Architecture

The workflow graph is constructed in [`src/main.py`](https://github.com/virattt/ai-hedge-fund/blob/main/src/main.py) within the `create_workflow` function. Each analyst and risk node receives the shared `AgentState` dictionary and writes its output to `state["data"]["analyst_signals"]` under its own key.

The built-in risk manager lives in [`src/agents/risk_manager.py`](https://github.com/virattt/ai-hedge-fund/blob/main/src/agents/risk_manager.py). It calculates volatility-adjusted position limits and correlation multipliers, then stores the results under the key `"risk_management_agent"` in the shared state. The portfolio manager reads per-ticker limits from that entry of `state["data"]["analyst_signals"]` to determine final position sizes. As a result, custom risk logic that leaves its final limits under the `"risk_management_agent"` key constrains downstream decisions without any change to the portfolio manager.

The default agent performs the following steps:

1. Pulls price history for all tickers using `get_prices` from [`src/tools/api.py`](https://github.com/virattt/ai-hedge-fund/blob/main/src/tools/api.py).
2. Computes **volatility metrics** via `calculate_volatility_metrics`.
3. Builds a **correlation matrix** across the ticker universe.
4. Calculates a **volatility-adjusted limit** for each ticker using `calculate_volatility_adjusted_limit`.
5. Adjusts the limit by a **correlation multiplier** via `calculate_correlation_multiplier`.
6. Stores the final per-ticker limits in `state["data"]["analyst_signals"]["risk_management_agent"]`.
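The six steps above can be sketched in plain Python on in-memory price lists. The function names are hypothetical stand-ins for the repository's `calculate_*` helpers. The formulas are also illustrative assumptions: √252 annualization, a 25% target volatility, a 5% to 25% clipping range, and the correlation thresholds. Check `risk_manager.py` for the rules and values the repository actually uses.

```python
import math
import statistics


def daily_returns(prices: list[float]) -> list[float]:
    # Simple percentage returns between consecutive closing prices.
    return [prices[i] / prices[i - 1] - 1.0 for i in range(1, len(prices))]


def annualized_volatility(prices: list[float]) -> float:
    # Sample standard deviation of daily returns, scaled by sqrt(252 trading days).
    return statistics.stdev(daily_returns(prices)) * math.sqrt(252)


def pearson(x: list[float], y: list[float]) -> float:
    # Pearson correlation of two equal-length return series (0.0 if either is flat).
    mx, my = statistics.fmean(x), statistics.fmean(y)
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    var_x = sum((a - mx) ** 2 for a in x)
    var_y = sum((b - my) ** 2 for b in y)
    if var_x == 0.0 or var_y == 0.0:
        return 0.0
    return cov / math.sqrt(var_x * var_y)


def volatility_adjusted_pct(vol: float, target_vol: float = 0.25,
                            base_pct: float = 0.20) -> float:
    # Hypothetical rule: scale the base allocation inversely with volatility,
    # then clip it to the range [5%, 25%] of portfolio value.
    if vol <= 0.0:
        return base_pct
    return min(max(base_pct * target_vol / vol, 0.05), 0.25)


def correlation_multiplier(avg_corr: float) -> float:
    # Hypothetical rule: shrink limits for tickers that move with the rest of the book.
    if avg_corr >= 0.8:
        return 0.7
    if avg_corr >= 0.5:
        return 0.85
    return 1.0


def position_limits(prices: dict[str, list[float]], portfolio_value: float) -> dict[str, float]:
    # Steps 2-5: volatility, correlations, volatility-adjusted limit, correlation multiplier.
    returns = {ticker: daily_returns(series) for ticker, series in prices.items()}
    limits = {}
    for ticker, series in prices.items():
        corrs = [pearson(returns[ticker], returns[other]) for other in prices if other != ticker]
        avg_corr = statistics.fmean(corrs) if corrs else 0.0
        pct = volatility_adjusted_pct(annualized_volatility(series))
        limits[ticker] = portfolio_value * pct * correlation_multiplier(avg_corr)
    return limits


prices = {
    "AAPL": [100.0, 101.0, 99.0, 102.0, 103.0],
    "MSFT": [200.0, 202.0, 198.0, 204.0, 206.0],
}
print(position_limits(prices, portfolio_value=100_000.0))
```

In the sample data, MSFT's prices are exactly twice AAPL's. Their returns are therefore identical, the correlation is 1, and both tickers receive the same limit with the 0.7 multiplier applied.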

## Method 1: Replace the Built-in Risk Agent

To completely override the default volatility-based risk logic, create a custom agent file and substitute it for the original node in `create_workflow`.

First, create a new file, `src/agents/custom_risk_manager.py`:

```python
import json

# numpy, pandas, get_prices, and prices_to_df are used by the elided price/volatility code
import numpy as np
import pandas as pd
from langchain_core.messages import HumanMessage

# Re-use the existing volatility helpers from the built-in agent
from src.agents.risk_manager import (
    calculate_correlation_multiplier,
    calculate_volatility_adjusted_limit,
    calculate_volatility_metrics,
)
from src.graph.state import AgentState, show_agent_reasoning
from src.tools.api import get_prices, prices_to_df
from src.utils.progress import progress


def custom_risk_management_agent(state: AgentState, agent_id: str = "risk_management_agent"):
    """
    Runs the original volatility/correlation logic, then caps every position
    at 5% of total portfolio value. agent_id defaults to the built-in key so
    the portfolio manager finds the resulting limits.
    """
    risk_analysis = {}  # ticker -> per-ticker limits and reasoning, filled in below

    # ... (existing price fetching and volatility calculations) ...

    # Custom addition: hard cap of 5% of portfolio value per position.
    # Assumes the elided code stored each ticker's latest price as "current_price".
    total_portfolio_value = state["data"]["portfolio"]["cash"]
    for ticker, pos in state["data"]["portfolio"]["positions"].items():
        price = risk_analysis.get(ticker, {}).get("current_price", 0.0)
        total_portfolio_value += (pos["long"] - pos["short"]) * price

    max_position_pct = 0.05
    position_cap = total_portfolio_value * max_position_pct

    for analysis in risk_analysis.values():
        final_limit = min(analysis["remaining_position_limit"], position_cap)
        analysis["remaining_position_limit"] = float(final_limit)
        analysis.setdefault("reasoning", {})["position_cap"] = (
            f"Hard cap applied: {max_position_pct:.0%} of portfolio value"
        )

    # Finalize state update
    progress.update_status(agent_id, None, "Done")
    message = HumanMessage(content=json.dumps(risk_analysis), name=agent_id)

    if state["metadata"]["show_reasoning"]:
        show_agent_reasoning(risk_analysis, "Custom Risk Management Agent")

    state["data"]["analyst_signals"][agent_id] = risk_analysis
    return {"messages": state["messages"] + [message], "data": state["data"]}
```
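The cap step does not depend on the elided price-fetching code. You can therefore pull it into a pure helper and test it on hand-built data. `apply_position_cap` and the `position_cap` reasoning key are hypothetical names. The input mirrors the `remaining_position_limit` and `reasoning` fields used in the agent above.

```python
from typing import Any


def apply_position_cap(risk_analysis: dict[str, dict[str, Any]],
                       total_portfolio_value: float,
                       max_position_pct: float = 0.05) -> dict[str, dict[str, Any]]:
    """Clamp each ticker's remaining_position_limit to a share of portfolio value (in place)."""
    cap = total_portfolio_value * max_position_pct
    for analysis in risk_analysis.values():
        analysis["remaining_position_limit"] = float(
            min(analysis["remaining_position_limit"], cap)
        )
        # Record why the limit may have changed, creating the reasoning dict if absent.
        analysis.setdefault("reasoning", {})["position_cap"] = (
            f"Hard cap applied: {max_position_pct:.0%} of portfolio (${cap:,.0f})"
        )
    return risk_analysis


sample = {
    "AAPL": {"remaining_position_limit": 12_000.0, "reasoning": {}},
    "MSFT": {"remaining_position_limit": 3_000.0},
}
capped = apply_position_cap(sample, total_portfolio_value=100_000.0)
print({t: a["remaining_position_limit"] for t, a in capped.items()})
# {'AAPL': 5000.0, 'MSFT': 3000.0}
```

Limits already below the cap pass through unchanged. The rule can only tighten a limit, never loosen it.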

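Then, wire it into [`src/main.py`](https://github.com/virattt/ai-hedge-fund/blob/main/src/main.py) by changing the function passed to the existing `add_node` call. Because the node name `"risk_management_agent"` stays the same, the existing edges from the analysts and to the portfolio manager still apply: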

```python
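from src.agents.custom_risk_manager import custom_risk_management_agent

# Inside create_workflow, edit the existing registration of the risk node
# instead of adding a second one: LangGraph raises a ValueError if the same
# node name is registered twice.
workflow.add_node("risk_management_agent", custom_risk_management_agent)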
```
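The cap in the Method 1 agent is a flat percentage of portfolio value, not a Value-at-Risk calculation. To make the limit reflect each ticker's loss history, one option is to size a long position so that its historical one-day VaR fits a budget. This minimal sketch assumes daily returns are already available, for example from the price data the agent fetches. The 1% VaR budget, the 95% confidence level, the "k-th worst return" quantile convention, and the function names are all hypothetical choices.

```python
import math


def historical_var_per_dollar(returns: list[float], confidence: float = 0.95) -> float:
    """One-day historical VaR of a $1 long position, as a positive loss fraction."""
    if not returns:
        raise ValueError("need at least one daily return")
    ordered = sorted(returns)  # worst (most negative) returns first
    # Use the k-th worst return, where k is the number of observations in the
    # (1 - confidence) tail; the small epsilon guards against float round-off.
    k = max(1, int((1.0 - confidence) * len(ordered) + 1e-9))
    return max(-ordered[k - 1], 0.0)


def var_budget_limit(returns: list[float], portfolio_value: float,
                     var_budget_pct: float = 0.01, confidence: float = 0.95) -> float:
    """Largest dollar position whose one-day VaR stays within var_budget_pct of the portfolio."""
    var_per_dollar = historical_var_per_dollar(returns, confidence)
    if var_per_dollar == 0.0:
        return math.inf  # no loss at this quantile in the sample; let other caps bind
    return portfolio_value * var_budget_pct / var_per_dollar
```

Inside the agent, you would take `min(existing_limit, var_budget_limit(...))` for each ticker, the same way the flat cap is applied. Historical VaR needs a reasonably long return window. With only a few dozen observations, the tail estimate is noisy.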

## Method 2: Insert an Additional Risk Node

To layer custom logic on top of the existing volatility calculations—such as sector exposure caps—add a secondary node that runs after the default risk agent.

Create a new file, `src/agents/sector_cap_risk.py`:

```python
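from src.graph.state import AgentState
from src.utils.progress import progress

# Hypothetical ticker-to-sector map for illustration; load a real one in practice.
DUMMY_SECTOR_MAP = {"AAPL": "Tech", "MSFT": "Tech", "TSLA": "Auto"}


def sector_cap_risk_agent(state: AgentState, agent_id: str = "sector_cap_risk"):
    """
    Caps the combined remaining position limit of each sector at 10% of
    portfolio value, after the volatility limits have been computed.
    """
    # Edited in place, so the portfolio manager sees the capped limits
    # under the key it already reads.
    risk_signals = state["data"]["analyst_signals"]["risk_management_agent"]
    max_sector_pct = 0.10

    # Portfolio value = cash + net value of open positions. Assumes the risk
    # agent stored each ticker's latest price as "current_price".
    portfolio = state["data"]["portfolio"]
    total_portfolio = portfolio["cash"]
    for ticker, pos in portfolio["positions"].items():
        price = risk_signals.get(ticker, {}).get("current_price", 0.0)
        total_portfolio += (pos["long"] - pos["short"]) * price
    max_sector_limit = max_sector_pct * total_portfolio

    # Sum the remaining limits per sector
    sector_allocations = {}
    for ticker, info in risk_signals.items():
        sector = DUMMY_SECTOR_MAP.get(ticker, "Other")
        sector_allocations[sector] = sector_allocations.get(sector, 0.0) + info["remaining_position_limit"]

    # Scale every ticker in an over-cap sector by the same factor, so the
    # sector's combined limit ends up exactly at the cap
    for ticker, info in risk_signals.items():
        sector = DUMMY_SECTOR_MAP.get(ticker, "Other")
        sector_total = sector_allocations[sector]
        if sector_total > max_sector_limit:
            scale = max_sector_limit / sector_total
            info["remaining_position_limit"] = float(max(info["remaining_position_limit"] * scale, 0.0))
            info.setdefault("reasoning", {})["sector_cap"] = (
                f"Sector ({sector}) cap applied: {max_sector_pct:.0%} of portfolio"
            )

    progress.update_status(agent_id, None, "Sector caps applied")
    # No second copy under agent_id: a new key would not be read as a limit
    # and could be mistaken for an analyst signal downstream.
    return {"messages": state["messages"], "data": state["data"]}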
```
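The proportional reduction for a sector $S$ can be written with the sector cap $C$ (10% of portfolio value) and per-ticker remaining limits $L_i$. The worked numbers use a hypothetical $100,000 portfolio with AAPL at $8,000 and MSFT at $6,000, both in Tech.

$$
s = \frac{C}{\sum_{i \in S} L_i}, \qquad L_i' = s \, L_i \quad \text{for all } i \in S \text{ whenever } \sum_{i \in S} L_i > C
$$

$$
C = 0.10 \times 100{,}000 = 10{,}000, \quad s = \frac{10{,}000}{8{,}000 + 6{,}000} = \frac{5}{7}, \quad L'_{\text{AAPL}} = \frac{40{,}000}{7} \approx 5{,}714.29, \quad L'_{\text{MSFT}} = \frac{30{,}000}{7} \approx 4{,}285.71
$$

The scaled limits sum to exactly $C$, and each ticker keeps its original share of the sector total.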

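Then modify `create_workflow` in [`src/main.py`](https://github.com/virattt/ai-hedge-fund/blob/main/src/main.py) to chain the nodes. Keep the existing `risk_management_agent` node, register the new node, and replace the direct edge from the risk manager to the portfolio manager: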

```python
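from src.agents.sector_cap_risk import sector_cap_risk_agent

# Inside create_workflow: keep the existing
#   workflow.add_node("risk_management_agent", risk_management_agent)
# line and register only the new node (re-adding an existing name raises a ValueError).
workflow.add_node("sector_cap_risk", sector_cap_risk_agent)

# Connect pipeline: analysts → risk → sector cap → portfolio manager.
# Remove the existing workflow.add_edge("risk_management_agent", "portfolio_manager")
# line; otherwise the portfolio manager is also triggered straight from the risk
# node and can run before the sector caps are applied.
workflow.add_edge("risk_management_agent", "sector_cap_risk")
workflow.add_edge("sector_cap_risk", "portfolio_manager")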
```

## Alternative Backend Graph Configuration

If you run the system via the FastAPI backend service, the same wiring logic appears in [`app/backend/services/graph.py`](https://github.com/virattt/ai-hedge-fund/blob/main/app/backend/services/graph.py). The `risk_manager_nodes` handling in this file mirrors the `create_workflow` function in [`src/main.py`](https://github.com/virattt/ai-hedge-fund/blob/main/src/main.py). Update this file instead if you are deploying the hedge fund as an API service rather than running the CLI.

## Key Files for Custom Risk Implementation

- **[`src/agents/risk_manager.py`](https://github.com/virattt/ai-hedge-fund/blob/main/src/agents/risk_manager.py)** – Contains `calculate_volatility_metrics`, `calculate_correlation_multiplier`, and the original `risk_management_agent` function.
- **[`src/main.py`](https://github.com/virattt/ai-hedge-fund/blob/main/src/main.py)** – Houses `create_workflow` where you register nodes and define edges between analysts, risk managers, and the portfolio manager.
- **[`src/graph/state.py`](https://github.com/virattt/ai-hedge-fund/blob/main/src/graph/state.py)** – Defines the `AgentState` TypedDict structure that all agents use for data exchange.
- **[`app/backend/services/graph.py`](https://github.com/virattt/ai-hedge-fund/blob/main/app/backend/services/graph.py)** – Backend service equivalent of `create_workflow` for API deployments.

## Summary

- The virattt/ai-hedge-fund workflow uses a **LangGraph state graph** where risk agents write position limits to `state["data"]["analyst_signals"]`.
- You can **replace** the default `risk_management_agent` by passing your custom function, in place of the original, to the existing `workflow.add_node("risk_management_agent", ...)` call in [`src/main.py`](https://github.com/virattt/ai-hedge-fund/blob/main/src/main.py).
- You can **chain** additional risk nodes, such as sector cap or VaR limit agents, by routing the edge from the default risk node through them before it reaches the portfolio manager.
- The portfolio manager reads position limits from `state["data"]["analyst_signals"]`. No downstream modifications are required as long as your risk nodes leave their final limits under the `"risk_management_agent"` key.

## Frequently Asked Questions

### Do I need to modify the portfolio manager to recognize custom risk signals?

No, provided your limits end up where the portfolio manager looks for them. It reads position limits from `state["data"]["analyst_signals"]` under the `"risk_management_agent"` key. You can either write your final limits under that key or update the portfolio manager's lookup to use your agent's ID. Writing them only under a new key such as `"custom_risk_manager"` is not enough for the limits to be applied.

### Can I run multiple risk agents in parallel instead of sequentially?

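Yes. LangGraph runs nodes in parallel when several edges leave the same node. Instead of chaining `workflow.add_edge("risk_management_agent", "sector_cap_risk")`, register one edge per target to fan out to independent risk agents. For example, use `workflow.add_edge("analyst_1", "risk_management_agent")` and `workflow.add_edge("analyst_1", "var_risk_agent")` (placeholder node names). `add_edge` takes a single destination. A list is accepted only as the source, as in `workflow.add_edge(["risk_management_agent", "var_risk_agent"], "risk_merge")`, which waits for both nodes to finish. That merge node must reconcile conflicting limits before the portfolio manager runs. Agents that depend on another risk agent's output, such as the sector cap agent above, must stay downstream of that agent rather than run in parallel with it.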
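A minimal sketch of such a reduction step keeps the most conservative (smallest) limit per ticker. It assumes each parallel agent emits the same per-ticker `remaining_position_limit` field. The agent IDs and the `binding_agent` field are hypothetical. A merge node built around this function would write its result under the key the portfolio manager reads.

```python
from typing import Any

Limits = dict[str, dict[str, Any]]


def merge_risk_limits(outputs_by_agent: dict[str, Limits]) -> Limits:
    """Keep the tightest remaining_position_limit per ticker across parallel risk agents."""
    merged: Limits = {}
    for agent_id, output in outputs_by_agent.items():
        for ticker, info in output.items():
            limit = float(info["remaining_position_limit"])
            current = merged.get(ticker)
            if current is None or limit < current["remaining_position_limit"]:
                # Copy the winning entry and note which agent produced the binding limit.
                merged[ticker] = {**info, "remaining_position_limit": limit, "binding_agent": agent_id}
    return merged


merged = merge_risk_limits({
    "volatility_risk": {"AAPL": {"remaining_position_limit": 8_000.0},
                        "TSLA": {"remaining_position_limit": 4_000.0}},
    "sector_risk": {"AAPL": {"remaining_position_limit": 5_000.0}},
})
print({t: (v["remaining_position_limit"], v["binding_agent"]) for t, v in merged.items()})
# {'AAPL': (5000.0, 'sector_risk'), 'TSLA': (4000.0, 'volatility_risk')}
```

Taking the minimum means every rule is respected at once. The trade-off is that the strictest agent always wins, even when its limit comes from a stale or noisy input.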

### Where should I store sector mappings or external data for risk calculations?

Store static mappings (like ticker-to-sector tables) in a JSON or CSV file alongside your agent, for example `src/agents/data/sector_map.json` (a file you create). Load them once when the agent module is imported. For dynamic data, you have two options. You can extend the `AgentState` definition in [`src/graph/state.py`](https://github.com/virattt/ai-hedge-fund/blob/main/src/graph/state.py) to include a new field. Alternatively, you can fetch the data inside the agent with the helpers in [`src/tools/api.py`](https://github.com/virattt/ai-hedge-fund/blob/main/src/tools/api.py), such as `get_prices` for price history.
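A minimal loader for a `{ticker: sector}` JSON file might look like this. The function names, the upper-casing of tickers, and the empty-map fallback for a missing file are hypothetical choices. Unknown tickers fall into an `"Other"` bucket, matching the sector cap example.

```python
import json
from pathlib import Path


def parse_sector_map(raw: str) -> dict[str, str]:
    # Normalize tickers to upper case so lookups match the workflow's ticker strings.
    data = json.loads(raw)
    return {str(ticker).upper(): str(sector) for ticker, sector in data.items()}


def load_sector_map(path: Path) -> dict[str, str]:
    """Read a {ticker: sector} JSON file; return an empty map if the file is missing."""
    if not path.exists():
        return {}
    return parse_sector_map(path.read_text(encoding="utf-8"))


def sector_for(ticker: str, sector_map: dict[str, str]) -> str:
    # Unknown tickers are grouped under "Other".
    return sector_map.get(ticker.upper(), "Other")
```

In the agent module, you could load it once at import time, for example `SECTOR_MAP = load_sector_map(Path(__file__).parent / "data" / "sector_map.json")`, and call `sector_for(ticker, SECTOR_MAP)` inside the agent function.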

### How do I test custom risk rules without executing live trades?

The project simulates trading decisions rather than placing real orders, so testing mainly means checking the limits your rule produces. Isolate your custom agent in a unit test: construct a mock state dictionary shaped like `AgentState` (from [`src/graph/state.py`](https://github.com/virattt/ai-hedge-fund/blob/main/src/graph/state.py)) with sample portfolio data, pass it to your risk agent function, and verify the output limits. If you set `state["metadata"]["show_reasoning"]` to `True`, the `show_agent_reasoning` utility in [`src/graph/state.py`](https://github.com/virattt/ai-hedge-fund/blob/main/src/graph/state.py) prints the decision trace for manual verification.
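Here is a minimal sketch of that approach. A plain dict stands in for `AgentState`, and a stand-in `toy_cap_agent` takes the place of your real agent; in practice you would import your own agent instead. The state keys mirror those used in this guide. `make_mock_state`, `toy_cap_agent`, and the sample values are hypothetical.

```python
import copy
from typing import Any


def make_mock_state(cash: float, limits: dict[str, float]) -> dict[str, Any]:
    """Build a dict shaped like the state used in this guide, with sample values."""
    return {
        "messages": [],
        "data": {
            "tickers": list(limits),
            "portfolio": {"cash": cash, "positions": {}},
            "analyst_signals": {
                "risk_management_agent": {
                    t: {"remaining_position_limit": v, "reasoning": {}} for t, v in limits.items()
                }
            },
        },
        "metadata": {"show_reasoning": False},
    }


def toy_cap_agent(state: dict[str, Any]) -> dict[str, Any]:
    # Stand-in for your custom agent: cap every limit at 5% of cash, in place.
    cap = 0.05 * state["data"]["portfolio"]["cash"]
    for info in state["data"]["analyst_signals"]["risk_management_agent"].values():
        info["remaining_position_limit"] = min(info["remaining_position_limit"], cap)
    return {"messages": state["messages"], "data": state["data"]}


state = make_mock_state(cash=100_000.0, limits={"AAPL": 12_000.0, "MSFT": 3_000.0})
before = copy.deepcopy(state)
result = toy_cap_agent(state)
signals = result["data"]["analyst_signals"]

# The agent should tighten, never loosen, limits and should not add new signal keys.
assert set(signals) == {"risk_management_agent"}
for ticker, info in signals["risk_management_agent"].items():
    old = before["data"]["analyst_signals"]["risk_management_agent"][ticker]
    assert info["remaining_position_limit"] <= old["remaining_position_limit"]
print({t: i["remaining_position_limit"] for t, i in signals["risk_management_agent"].items()})
# {'AAPL': 5000.0, 'MSFT': 3000.0}
```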