How to Add Custom Risk Management Rules to the AI Hedge Fund Agent Workflow

You can add custom risk management rules to the virattt/ai-hedge-fund workflow by either extending the existing risk_management_agent in src/agents/risk_manager.py or inserting an additional risk node into the state graph defined in src/main.py.

The virattt/ai-hedge-fund system employs a state-graph architecture built on LangGraph, where each step of the investment workflow operates as a node that reads from and writes to a shared AgentState. To add custom risk management rules, you modify the graph topology in create_workflow to include your own risk agent that enforces position limits, VaR constraints, or sector caps before the portfolio manager executes trades.

Understanding the Risk Management Architecture

The workflow graph is constructed in src/main.py within the create_workflow function. Each node receives the shared AgentState dictionary and writes analytical outputs to state["data"]["analyst_signals"].

The built-in risk manager lives in src/agents/risk_manager.py. It calculates volatility-adjusted position limits and correlation multipliers, storing results under the key "risk_management_agent" in the shared state. Because the portfolio manager reads from state["data"]["analyst_signals"] to determine final position sizes, any custom risk logic that writes to this location will automatically constrain downstream decisions.
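
The read/write pattern described above can be sketched without the repository installed. This is a minimal illustration, not the project's code: SketchState is a simplified stand-in for AgentState, toy_risk_node is a hypothetical node, and the "tickers" key and the fixed 1,000.0 limit are assumptions made for the example.

```python
from typing import Any, Dict, List, TypedDict


class SketchState(TypedDict):
    """Simplified stand-in for AgentState (assumed shape, not the repo's definition)."""
    messages: List[Any]
    data: Dict[str, Any]
    metadata: Dict[str, Any]


def toy_risk_node(state: SketchState, agent_id: str = "risk_management_agent") -> Dict[str, Any]:
    """Hypothetical node: write a fixed per-ticker limit into analyst_signals."""
    limits = {
        ticker: {"remaining_position_limit": 1_000.0, "reasoning": {}}
        for ticker in state["data"]["tickers"]
    }
    # Downstream nodes find the limits under this agent's key.
    state["data"]["analyst_signals"][agent_id] = limits
    return {"messages": state["messages"], "data": state["data"]}


state: SketchState = {
    "messages": [],
    "data": {"tickers": ["AAPL", "MSFT"], "analyst_signals": {}},
    "metadata": {"show_reasoning": False},
}
update = toy_risk_node(state)

# What a downstream node would read back from the shared state.
downstream_view = update["data"]["analyst_signals"]["risk_management_agent"]
print(downstream_view["AAPL"]["remaining_position_limit"])  # 1000.0
```

Any node that later rewrites entries under the same key changes what downstream nodes read, which is the mechanism both methods below rely on.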

The default agent performs the following steps:

  1. Pulls price history for all tickers using get_prices from src/tools/api.py.
  2. Computes volatility metrics via calculate_volatility_metrics.
  3. Builds a correlation matrix across the ticker universe.
  4. Calculates a volatility-adjusted limit for each ticker using calculate_volatility_adjusted_limit.
  5. Adjusts the limit by a correlation multiplier via calculate_correlation_multiplier.
  6. Stores the final per-ticker limits in state["data"]["analyst_signals"]["risk_management_agent"].
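
The shape of steps 2 through 5 can be sketched in plain Python. The thresholds and formulas here (a 20% base allocation, a 15% volatility pivot, a 0.7 multiplier above 0.8 average correlation) are assumptions chosen for illustration and are not the values used in src/agents/risk_manager.py; the function names are likewise hypothetical.

```python
import math
import statistics
from typing import Dict, List


def annualized_volatility(prices: List[float], trading_days: int = 252) -> float:
    """Annualized standard deviation of simple daily returns."""
    returns = [curr / prev - 1.0 for prev, curr in zip(prices, prices[1:])]
    return statistics.stdev(returns) * math.sqrt(trading_days)


def toy_volatility_limit_pct(vol: float, base_pct: float = 0.20) -> float:
    """Hypothetical rule: keep the full base up to 15% volatility, then shrink it."""
    if vol <= 0:
        return base_pct
    return base_pct * min(1.0, 0.15 / vol)


def toy_correlation_multiplier(avg_corr: float) -> float:
    """Hypothetical rule: highly correlated names get a smaller allowance."""
    return 0.7 if avg_corr >= 0.8 else 1.0


def position_limits(
    prices: Dict[str, List[float]],
    avg_corr: Dict[str, float],
    portfolio_value: float,
) -> Dict[str, float]:
    """Combine the volatility limit and correlation multiplier per ticker."""
    limits = {}
    for ticker, series in prices.items():
        vol = annualized_volatility(series)
        pct = toy_volatility_limit_pct(vol) * toy_correlation_multiplier(avg_corr[ticker])
        limits[ticker] = portfolio_value * pct
    return limits


# Made-up price series: one calm ticker, one volatile and highly correlated ticker.
sample_prices = {
    "CALM": [100.0, 100.5, 100.2, 100.8, 100.6],
    "WILD": [100.0, 110.0, 95.0, 112.0, 98.0],
}
limits = position_limits(sample_prices, {"CALM": 0.1, "WILD": 0.9}, 100_000.0)
print(limits)
```

The calm ticker keeps the full base allocation, while the volatile, correlated ticker ends up with a much smaller limit.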

Method 1: Replace the Built-in Risk Agent

To completely override the default volatility-based risk logic, create a custom agent file and substitute it for the original node in create_workflow.

First, create src/agents/custom_risk_manager.py:

from langchain_core.messages import HumanMessage
from src.graph.state import AgentState, show_agent_reasoning
from src.utils.progress import progress
from src.tools.api import get_prices, prices_to_df
import json
import numpy as np
import pandas as pd

def custom_risk_management_agent(state: AgentState, agent_id: str = "custom_risk_manager"):
    """
    Extends the original risk agent with a 5% VaR hard cap.
    """
    # Re-use the existing volatility helpers
    from src.agents.risk_manager import (
        calculate_volatility_metrics,
        calculate_volatility_adjusted_limit,
        calculate_correlation_multiplier,
    )

    # ... (existing price fetching and volatility calculations) ...
    # The elided step must produce risk_analysis: a dict keyed by ticker whose
    # entries hold "remaining_position_limit" and a "reasoning" dict.

    # Custom addition: Enforce 5% VaR cap
    # Positions without a "current_price" field contribute 0 to the total.
    total_portfolio_value = state["data"]["portfolio"]["cash"]
    for pos in state["data"]["portfolio"]["positions"].values():
        total_portfolio_value += (pos["long"] - pos["short"]) * pos.get("current_price", 0)
    
    var_cap_pct = 0.05
    
    for ticker, analysis in risk_analysis.items():
        existing_limit = analysis["remaining_position_limit"]
        var_cap = total_portfolio_value * var_cap_pct
        final_limit = min(existing_limit, var_cap)
        analysis["remaining_position_limit"] = float(final_limit)
        analysis["reasoning"]["var_cap"] = f"VaR hard-cap applied: {var_cap_pct:.0%} of portfolio"
    
    # Finalize state update
    progress.update_status(agent_id, None, "Done")
    message = HumanMessage(content=json.dumps(risk_analysis), name=agent_id)
    
    if state["metadata"]["show_reasoning"]:
        show_agent_reasoning(risk_analysis, "Custom Risk Management Agent")
    
    state["data"]["analyst_signals"][agent_id] = risk_analysis
    return {"messages": state["messages"] + [message], "data": state["data"]}
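
Because the agent above elides the price and volatility steps, the cap logic is easier to check in isolation. The following is a minimal sketch under stated assumptions: portfolio_value and apply_hard_cap are hypothetical helpers, and current prices are passed in explicitly rather than read from a "current_price" field on each position.

```python
import math
from typing import Any, Dict


def portfolio_value(portfolio: Dict[str, Any], prices: Dict[str, float]) -> float:
    """Cash plus net (long minus short) shares marked at the supplied prices."""
    total = float(portfolio["cash"])
    for ticker, pos in portfolio["positions"].items():
        total += (pos["long"] - pos["short"]) * prices.get(ticker, 0.0)
    return total


def apply_hard_cap(
    risk_analysis: Dict[str, Dict[str, Any]],
    total_value: float,
    cap_pct: float = 0.05,
) -> Dict[str, Dict[str, Any]]:
    """Clamp each ticker's remaining limit to cap_pct of total portfolio value."""
    cap = total_value * cap_pct
    for analysis in risk_analysis.values():
        capped = min(analysis["remaining_position_limit"], cap)
        analysis["remaining_position_limit"] = float(capped)
        analysis["reasoning"]["var_cap"] = f"Hard cap applied: {cap_pct:.0%} of portfolio"
    return risk_analysis


# Made-up portfolio: 80,000 cash plus 100 AAPL shares at 200 each.
portfolio = {"cash": 80_000.0, "positions": {"AAPL": {"long": 100, "short": 0}}}
total = portfolio_value(portfolio, {"AAPL": 200.0})

risk_analysis = {
    "AAPL": {"remaining_position_limit": 12_000.0, "reasoning": {}},
    "MSFT": {"remaining_position_limit": 3_000.0, "reasoning": {}},
}
capped = apply_hard_cap(risk_analysis, total)
print(total, {t: a["remaining_position_limit"] for t, a in capped.items()})
```

With a total value of 100,000, the 12,000 limit is cut to the 5,000 cap and the 3,000 limit is left alone. Note that this is a flat percentage-of-portfolio cap; a true VaR constraint would also need a loss distribution.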

Then, wire it into src/main.py:

from src.agents.custom_risk_manager import custom_risk_management_agent

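# Replace the default node: change the existing add_node call for this name
# rather than adding a second one, since a node name can be registered only once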
workflow.add_node("risk_management_agent", custom_risk_management_agent)

Method 2: Insert an Additional Risk Node

To layer custom logic on top of the existing volatility calculations—such as sector exposure caps—add a secondary node that runs after the default risk agent.

Create src/agents/sector_cap_risk.py:

from src.graph.state import AgentState
from src.utils.progress import progress
import json

def sector_cap_risk_agent(state: AgentState, agent_id: str = "sector_cap_risk"):
    """
    Enforces a 10% maximum exposure per sector after volatility limits are applied.
    """
    risk_signals = state["data"]["analyst_signals"]["risk_management_agent"]
    # Placeholder mapping; unmapped tickers fall into "Other"
    dummy_sector_map = {"AAPL": "Tech", "MSFT": "Tech", "TSLA": "Auto"}

    # Sum the remaining limits per sector
    sector_allocations = {}
    for ticker, info in risk_signals.items():
        sector = dummy_sector_map.get(ticker, "Other")
        sector_allocations.setdefault(sector, 0.0)
        sector_allocations[sector] += info["remaining_position_limit"]

    # The cap base is cash only; open positions are not counted here
    total_portfolio = state["data"]["portfolio"]["cash"]
    max_sector_limit = max(0.10 * total_portfolio, 0.0)

    # Scale every ticker in an over-allocated sector by the same factor so the
    # sector total lands on the cap. The totals are not mutated inside the loop,
    # otherwise only the first ticker in each sector would be reduced.
    for ticker, info in risk_signals.items():
        sector = dummy_sector_map.get(ticker, "Other")
        sector_total = sector_allocations[sector]
        if sector_total > max_sector_limit:
            scale = max_sector_limit / sector_total
            info["remaining_position_limit"] = float(info["remaining_position_limit"] * scale)
            info["reasoning"]["sector_cap"] = f"Sector ({sector}) cap applied"
    
    progress.update_status(agent_id, None, "Sector caps applied")
    state["data"]["analyst_signals"][agent_id] = risk_signals
    return {"messages": state["messages"], "data": state["data"]}
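
A worked example makes the proportional reduction concrete. This sketch uses a hypothetical helper, scale_to_sector_cap, with made-up limits and a made-up cap; it assumes the cap is non-negative and works on plain numbers rather than the agent state.

```python
from typing import Dict


def scale_to_sector_cap(
    limits: Dict[str, float],
    sectors: Dict[str, str],
    cap: float,
) -> Dict[str, float]:
    """Scale limits so that no sector's total exceeds cap (cap assumed >= 0)."""
    # First pass: total limit per sector.
    totals: Dict[str, float] = {}
    for ticker, limit in limits.items():
        sector = sectors.get(ticker, "Other")
        totals[sector] = totals.get(sector, 0.0) + limit

    # Second pass: one shared factor per sector keeps relative weights intact.
    scaled: Dict[str, float] = {}
    for ticker, limit in limits.items():
        total = totals[sectors.get(ticker, "Other")]
        factor = cap / total if total > cap else 1.0
        scaled[ticker] = limit * factor
    return scaled


limits = {"AAPL": 8000.0, "MSFT": 12000.0, "TSLA": 3000.0}
sectors = {"AAPL": "Tech", "MSFT": "Tech", "TSLA": "Auto"}
capped = scale_to_sector_cap(limits, sectors, cap=10000.0)
print(capped)  # {'AAPL': 4000.0, 'MSFT': 6000.0, 'TSLA': 3000.0}
```

Tech totals 20,000 against a 10,000 cap, so both Tech tickers are halved; Auto is under the cap and stays unchanged.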

Modify the graph in src/main.py to chain the nodes:

from src.agents.sector_cap_risk import sector_cap_risk_agent

workflow.add_node("risk_management_agent", risk_management_agent)
workflow.add_node("sector_cap_risk", sector_cap_risk_agent)

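# Connect pipeline: analysts → risk → sector cap → portfolio manager
# If create_workflow already has a direct edge from risk_management_agent to
# portfolio_manager, replace it with the two edges below.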
workflow.add_edge("risk_management_agent", "sector_cap_risk")
workflow.add_edge("sector_cap_risk", "portfolio_manager")

Alternative Backend Graph Configuration

If you run the system via the FastAPI backend service, the same wiring logic appears in app/backend/services/graph.py. The risk_manager_nodes handling in this file mirrors the create_workflow function in src/main.py. Update this file instead if you are deploying the hedge fund as an API service rather than running the CLI.

Key Files for Custom Risk Implementation

  • src/agents/risk_manager.py – Contains calculate_volatility_metrics, calculate_correlation_multiplier, and the original risk_management_agent function.
  • src/main.py – Houses create_workflow where you register nodes and define edges between analysts, risk managers, and the portfolio manager.
  • src/graph/state.py – Defines the AgentState TypedDict structure that all agents use for data exchange.
  • app/backend/services/graph.py – Backend service equivalent of create_workflow for API deployments.

Summary

  • The virattt/ai-hedge-fund workflow uses a LangGraph state graph where risk agents write position limits to state["data"]["analyst_signals"].
  • You can replace the default risk_management_agent by importing a custom function and passing it to workflow.add_node() in src/main.py.
  • You can chain additional risk nodes—such as sector cap or VaR limit agents—by adding edges between the default risk node and the portfolio manager.
  • The portfolio manager automatically reads from state["data"]["analyst_signals"], so no downstream modifications are required when you add or replace risk nodes.

Frequently Asked Questions

Do I need to modify the portfolio manager to recognize custom risk signals?

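Usually not, provided your limits land under the key the portfolio manager reads. The portfolio manager looks up signals in state["data"]["analyst_signals"] by agent ID, and the built-in agent stores its limits under "risk_management_agent". A node that adjusts those entries in place, as the sector cap agent in Method 2 does, is picked up without code changes. A replacement agent that writes only under a different agent_id (e.g., "custom_risk_manager") is respected only if the portfolio manager looks up that key, so either write your output under "risk_management_agent" or confirm which key the portfolio manager reads.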

Can I run multiple risk agents in parallel instead of sequentially?

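Yes. The graph supports parallel execution. LangGraph's add_edge takes a single destination node, so fan out by calling it once per target from the same source: workflow.add_edge("analyst_1", "risk_management_agent") and workflow.add_edge("analyst_1", "sector_cap_risk"), where "analyst_1" stands for whichever node feeds the risk stage. This only suits risk agents that compute their limits independently; the sector cap agent from Method 2 reads the default agent's output and must run after it. With parallel agents you must also implement a reduction step that merges conflicting limits before the portfolio manager node.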
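
One conservative way to write that reduction step is to keep the tightest limit per ticker. The sketch below assumes each agent emits the same per-ticker structure with a "remaining_position_limit" field; merge_limits and the two sample agents' outputs are hypothetical.

```python
from typing import Any, Dict


def merge_limits(*signal_sets: Dict[str, Dict[str, Any]]) -> Dict[str, float]:
    """Return the smallest remaining_position_limit seen for each ticker."""
    merged: Dict[str, float] = {}
    for signals in signal_sets:
        for ticker, info in signals.items():
            limit = float(info["remaining_position_limit"])
            # First sighting seeds the value; later ones can only tighten it.
            merged[ticker] = min(merged.get(ticker, limit), limit)
    return merged


# Made-up outputs from two independent risk agents.
volatility_limits = {
    "AAPL": {"remaining_position_limit": 8000.0},
    "TSLA": {"remaining_position_limit": 3000.0},
}
liquidity_limits = {
    "AAPL": {"remaining_position_limit": 5000.0},
    "TSLA": {"remaining_position_limit": 4500.0},
}
merged = merge_limits(volatility_limits, liquidity_limits)
print(merged)  # {'AAPL': 5000.0, 'TSLA': 3000.0}
```

Taking the minimum means no single agent can loosen a limit set by another, which is usually the safer default for risk constraints.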

Where should I store sector mappings or external data for risk calculations?

Store static mappings (like ticker-to-sector tables) in a JSON or CSV file within your agent directory (e.g., src/agents/data/sector_map.json) and load them at agent initialization. For dynamic data, extend the AgentState definition in src/graph/state.py to include a new field, or fetch external data via the existing get_prices utility in src/tools/api.py.
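
Loading a static mapping could look like the following sketch. The file name, its flat ticker-to-sector JSON layout, and the helpers load_sector_map and sector_for are all assumptions; the demo writes a sample file to a temporary directory so it runs on its own.

```python
import json
import tempfile
from pathlib import Path
from typing import Dict


def load_sector_map(path: Path) -> Dict[str, str]:
    """Read a ticker-to-sector JSON object and normalize tickers to upper case."""
    with path.open(encoding="utf-8") as handle:
        raw = json.load(handle)
    return {str(ticker).upper(): str(sector) for ticker, sector in raw.items()}


def sector_for(ticker: str, sector_map: Dict[str, str], default: str = "Other") -> str:
    """Look up a ticker's sector, falling back to a default bucket."""
    return sector_map.get(ticker.upper(), default)


# Demo: write a sample file to a temporary directory, then load it back.
with tempfile.TemporaryDirectory() as tmp:
    sample_path = Path(tmp) / "sector_map.json"
    sample_path.write_text(json.dumps({"aapl": "Tech", "TSLA": "Auto"}), encoding="utf-8")
    sector_map = load_sector_map(sample_path)

print(sector_for("AAPL", sector_map))  # Tech
print(sector_for("XOM", sector_map))   # Other
```

Loading the file once at module import and reusing the dictionary avoids re-reading it on every graph run.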

How do I test custom risk rules without executing live trades?

Run the CLI in dry-run mode or isolate your custom agent in a unit test. Import AgentState from src/graph/state.py, construct a mock state dictionary with sample portfolio data, and pass it to your risk agent function to verify the output limits. The show_agent_reasoning utility in src/graph/state.py will print the decision trace for manual verification.
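
A unit test along those lines might be structured as below. To stay runnable without the repository, this sketch builds the mock state as a plain dictionary and tests a hypothetical stand-in rule, toy_cap_agent; in practice you would import your own agent function and pass it the same kind of state.

```python
from typing import Any, Dict


def make_mock_state(cash: float, limits: Dict[str, float]) -> Dict[str, Any]:
    """Build a state dict with the keys the article's agents read (assumed shape)."""
    return {
        "messages": [],
        "metadata": {"show_reasoning": False},
        "data": {
            "portfolio": {"cash": cash, "positions": {}},
            "analyst_signals": {
                "risk_management_agent": {
                    ticker: {"remaining_position_limit": limit, "reasoning": {}}
                    for ticker, limit in limits.items()
                }
            },
        },
    }


def toy_cap_agent(
    state: Dict[str, Any], agent_id: str = "toy_cap", cap_pct: float = 0.05
) -> Dict[str, Any]:
    """Hypothetical rule under test: no limit may exceed cap_pct of cash."""
    cap = state["data"]["portfolio"]["cash"] * cap_pct
    signals = state["data"]["analyst_signals"]["risk_management_agent"]
    for info in signals.values():
        info["remaining_position_limit"] = min(info["remaining_position_limit"], cap)
    state["data"]["analyst_signals"][agent_id] = signals
    return {"messages": state["messages"], "data": state["data"]}


def test_cap_is_respected() -> None:
    state = make_mock_state(cash=100_000.0, limits={"AAPL": 20_000.0, "TSLA": 1_000.0})
    result = toy_cap_agent(state)
    out = result["data"]["analyst_signals"]["toy_cap"]
    # The oversized limit is clamped; the small one passes through untouched.
    assert out["AAPL"]["remaining_position_limit"] <= 5_000.0 + 1e-9
    assert out["TSLA"]["remaining_position_limit"] == 1_000.0


test_cap_is_respected()
```

The same test function can be collected by pytest if it lives in a file whose name starts with test_.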
