How to Add Custom Risk Management Rules to the AI Hedge Fund Agent Workflow
You can add custom risk management rules to the virattt/ai-hedge-fund workflow by either extending the existing risk_management_agent in src/agents/risk_manager.py or inserting an additional risk node into the state graph defined in src/main.py.
The virattt/ai-hedge-fund system employs a state-graph architecture built on LangGraph, where each step of the investment workflow operates as a node that reads from and writes to a shared AgentState. To add custom risk management rules, you modify the graph topology in create_workflow to include your own risk agent that enforces position limits, VaR constraints, or sector caps before the portfolio manager executes trades.
Understanding the Risk Management Architecture
The workflow graph is constructed in src/main.py within the create_workflow function. Each node receives the shared AgentState dictionary and writes analytical outputs to state["data"]["analyst_signals"].
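To make the data flow concrete, the sketch below imitates a single node using plain dictionaries. It assumes only the state["data"]["analyst_signals"] layout described above; the node name example_node, the "tickers" field, and the sample values are hypothetical, and the real AgentState may carry additional fields.

```python
from typing import Any, Dict

State = Dict[str, Any]  # stand-in for the project's AgentState (assumption)


def example_node(state: State, agent_id: str = "example_agent") -> State:
    """Hypothetical node: reads tickers from the shared state and writes one signal per ticker."""
    tickers = state["data"]["tickers"]  # "tickers" is an assumed field name
    signals = {ticker: {"remaining_position_limit": 1000.0} for ticker in tickers}
    # Each node publishes its output under its own key in analyst_signals.
    state["data"]["analyst_signals"][agent_id] = signals
    return {"messages": state["messages"], "data": state["data"]}


mock_state: State = {
    "messages": [],
    "data": {"tickers": ["AAPL", "MSFT"], "analyst_signals": {}},
    "metadata": {"show_reasoning": False},
}
result = example_node(mock_state)
print(result["data"]["analyst_signals"]["example_agent"]["AAPL"])
```

Downstream nodes would then look up the "example_agent" key to consume these values.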
The built-in risk manager lives in src/agents/risk_manager.py. It calculates volatility-adjusted position limits and correlation multipliers, storing results under the key "risk_management_agent" in the shared state. Because the portfolio manager reads from state["data"]["analyst_signals"] to determine final position sizes, any custom risk logic that writes to this location will automatically constrain downstream decisions.
The default agent performs the following steps:
- Pulls price history for all tickers using get_prices from src/tools/api.py.
- Computes volatility metrics via calculate_volatility_metrics.
- Builds a correlation matrix across the ticker universe.
- Calculates a volatility-adjusted limit for each ticker using calculate_volatility_adjusted_limit.
- Adjusts the limit by a correlation multiplier via calculate_correlation_multiplier.
- Stores the final per-ticker limits in state["data"]["analyst_signals"]["risk_management_agent"].
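The steps above can be illustrated with a simplified, self-contained sketch. The tier thresholds, the multiplier rule, and the sketch_* function names are hypothetical and are not taken from src/agents/risk_manager.py; they only show how a volatility figure and a correlation figure could combine into a dollar limit.

```python
import math
import statistics
from typing import List


def annualized_volatility(prices: List[float]) -> float:
    """Annualized standard deviation of daily simple returns (252 trading days assumed)."""
    returns = [(b - a) / a for a, b in zip(prices, prices[1:])]
    return statistics.stdev(returns) * math.sqrt(252)


def sketch_volatility_limit_pct(vol: float) -> float:
    """Hypothetical tiers: lower volatility allows a larger share of the portfolio."""
    if vol < 0.15:
        return 0.25
    if vol < 0.30:
        return 0.20
    return 0.10


def sketch_correlation_multiplier(avg_corr: float) -> float:
    """Hypothetical rule: shrink the limit for highly correlated holdings."""
    return 0.7 if avg_corr >= 0.8 else 1.0


prices = [100.0, 101.0, 99.5, 100.5, 102.0, 101.0]
vol = annualized_volatility(prices)
portfolio_value = 100_000.0
limit = portfolio_value * sketch_volatility_limit_pct(vol) * sketch_correlation_multiplier(0.85)
print(f"volatility={vol:.2%}, position limit=${limit:,.0f}")
```

The point of the sketch is the multiplication order: a base percentage chosen from volatility, scaled down by a correlation factor, then applied to portfolio value.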
Method 1: Replace the Built-in Risk Agent
To completely override the default volatility-based risk logic, create a custom agent file and substitute it for the original node in create_workflow.
First, create src/agents/custom_risk_manager.py:
import json

import numpy as np
import pandas as pd
from langchain_core.messages import HumanMessage

from src.graph.state import AgentState, show_agent_reasoning
from src.tools.api import get_prices, prices_to_df
from src.utils.progress import progress


def custom_risk_management_agent(state: AgentState, agent_id: str = "custom_risk_manager"):
    """
    Extends the original risk agent with a hard cap of 5% of portfolio value per ticker.

    The cap is labeled "VaR" here, but it is a flat percentage of portfolio value,
    not a statistical value-at-risk estimate.
    """
    # Re-use existing volatility logic
    from src.agents.risk_manager import (
        calculate_volatility_metrics,
        calculate_volatility_adjusted_limit,
        calculate_correlation_multiplier,
    )

    # ... (existing price fetching and volatility calculations) ...
    # The elided section must build `risk_analysis`, a dict keyed by ticker.

    # Custom addition: Enforce 5% VaR cap
    total_portfolio_value = state["data"]["portfolio"]["cash"]
    for pos in state["data"]["portfolio"]["positions"].values():
        # Positions without a "current_price" entry contribute nothing to the total
        total_portfolio_value += (pos["long"] - pos["short"]) * pos.get("current_price", 0)

    var_cap_pct = 0.05
    var_cap = total_portfolio_value * var_cap_pct
    for ticker, analysis in risk_analysis.items():
        existing_limit = analysis["remaining_position_limit"]
        final_limit = min(existing_limit, var_cap)
        analysis["remaining_position_limit"] = float(final_limit)
        analysis["reasoning"]["var_cap"] = f"VaR hard-cap applied: {var_cap_pct:.0%} of portfolio"

    # Finalize state update
    progress.update_status(agent_id, None, "Done")
    message = HumanMessage(content=json.dumps(risk_analysis), name=agent_id)
    if state["metadata"]["show_reasoning"]:
        show_agent_reasoning(risk_analysis, "Custom Risk Management Agent")

    # Results are stored under agent_id. If the portfolio manager looks up limits
    # under "risk_management_agent", use that key here instead.
    state["data"]["analyst_signals"][agent_id] = risk_analysis
    return {"messages": state["messages"] + [message], "data": state["data"]}
Then, wire it into src/main.py:
from src.agents.custom_risk_manager import custom_risk_management_agent
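# Replace the default node: edit the existing add_node call for "risk_management_agent"
# rather than adding a second one, since a node name can only be registered once.
workflow.add_node("risk_management_agent", custom_risk_management_agent)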
Method 2: Insert an Additional Risk Node
To layer custom logic on top of the existing volatility calculations—such as sector exposure caps—add a secondary node that runs after the default risk agent.
Create src/agents/sector_cap_risk.py:
import json

from src.graph.state import AgentState
from src.utils.progress import progress


def sector_cap_risk_agent(state: AgentState, agent_id: str = "sector_cap_risk"):
    """
    Enforces a 10% maximum exposure per sector after volatility limits are applied.
    """
    risk_signals = state["data"]["analyst_signals"]["risk_management_agent"]
    dummy_sector_map = {"AAPL": "Tech", "MSFT": "Tech", "TSLA": "Auto"}

    # Sum the current limits per sector
    sector_allocations = {}
    for ticker, info in risk_signals.items():
        sector = dummy_sector_map.get(ticker, "Other")
        sector_allocations.setdefault(sector, 0.0)
        sector_allocations[sector] += info["remaining_position_limit"]

    # Cash is used as a simple proxy for total portfolio value
    total_portfolio = state["data"]["portfolio"]["cash"]
    max_sector_limit = 0.10 * total_portfolio

    for ticker, info in risk_signals.items():
        sector = dummy_sector_map.get(ticker, "Other")
        if sector_allocations[sector] > max_sector_limit:
            excess = sector_allocations[sector] - max_sector_limit
            # Each ticker absorbs a share of the excess proportional to its limit
            reduction = excess * (info["remaining_position_limit"] / sector_allocations[sector])
            new_limit = max(info["remaining_position_limit"] - reduction, 0.0)
            info["remaining_position_limit"] = float(new_limit)
            info["reasoning"]["sector_cap"] = f"Sector ({sector}) cap applied"
    # sector_allocations keeps the pre-cap totals throughout the loop so that every
    # ticker in an over-allocated sector is scaled, not just the first one.

    progress.update_status(agent_id, None, "Sector caps applied")
    state["data"]["analyst_signals"][agent_id] = risk_signals
    return {"messages": state["messages"], "data": state["data"]}
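The proportional reduction in the agent above is equivalent to multiplying every limit in an over-allocated sector by cap / sector_total. The standalone sketch below works through that arithmetic with made-up numbers; scale_to_sector_cap is a hypothetical helper, not part of the repository.

```python
from typing import Dict


def scale_to_sector_cap(
    limits: Dict[str, float], sectors: Dict[str, str], cap: float
) -> Dict[str, float]:
    """Scale each ticker's limit so that no sector's combined limit exceeds `cap`."""
    totals: Dict[str, float] = {}
    for ticker, limit in limits.items():
        sector = sectors.get(ticker, "Other")
        totals[sector] = totals.get(sector, 0.0) + limit

    scaled: Dict[str, float] = {}
    for ticker, limit in limits.items():
        total = totals[sectors.get(ticker, "Other")]
        # Only sectors above the cap shrink; every ticker in a sector shares one factor.
        factor = cap / total if total > cap else 1.0
        scaled[ticker] = limit * factor
    return scaled


limits = {"AAPL": 12_000.0, "MSFT": 8_000.0, "TSLA": 5_000.0}
sectors = {"AAPL": "Tech", "MSFT": "Tech", "TSLA": "Auto"}
scaled = scale_to_sector_cap(limits, sectors, cap=10_000.0)
print(scaled)
```

With a 10,000 cap, the Tech sector's combined 20,000 is halved (AAPL to 6,000 and MSFT to 4,000), while TSLA stays at 5,000 because Auto is under the cap.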
Modify the graph in src/main.py to chain the nodes:
from src.agents.sector_cap_risk import sector_cap_risk_agent
workflow.add_node("risk_management_agent", risk_management_agent)
workflow.add_node("sector_cap_risk", sector_cap_risk_agent)
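# Connect pipeline: analysts → risk → sector cap → portfolio manager
# If create_workflow already adds a direct risk_management_agent → portfolio_manager edge,
# remove it so the portfolio manager only runs after the sector cap node.
workflow.add_edge("risk_management_agent", "sector_cap_risk")
workflow.add_edge("sector_cap_risk", "portfolio_manager")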
Alternative Backend Graph Configuration
If you run the system via the FastAPI backend service, the same wiring logic appears in app/backend/services/graph.py. The risk_manager_nodes handling in this file mirrors the create_workflow function in src/main.py. Update this file instead if you are deploying the hedge fund as an API service rather than running the CLI.
Key Files for Custom Risk Implementation
- src/agents/risk_manager.py – Contains calculate_volatility_metrics, calculate_correlation_multiplier, and the original risk_management_agent function.
- src/main.py – Houses create_workflow, where you register nodes and define edges between analysts, risk managers, and the portfolio manager.
- src/graph/state.py – Defines the AgentState TypedDict structure that all agents use for data exchange.
- app/backend/services/graph.py – Backend service equivalent of create_workflow for API deployments.
Summary
- The virattt/ai-hedge-fund workflow uses a LangGraph state graph where risk agents write position limits to state["data"]["analyst_signals"].
- You can replace the default risk_management_agent by importing a custom function and passing it to workflow.add_node() in src/main.py.
- You can chain additional risk nodes, such as sector cap or VaR limit agents, by adding edges between the default risk node and the portfolio manager.
- The portfolio manager reads limits from state["data"]["analyst_signals"], so no downstream modifications are required as long as your limits land under the key it looks up.
Frequently Asked Questions
Do I need to modify the portfolio manager to recognize custom risk signals?
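Usually not, provided the limits land under the key the portfolio manager looks up. Agents publish their output to state["data"]["analyst_signals"] keyed by agent ID, and the built-in risk manager uses "risk_management_agent". If the portfolio manager reads that key specifically, a custom agent that writes only under its own agent_id (e.g., "custom_risk_manager") will be ignored, so either write your limits under "risk_management_agent" or check the lookup in the portfolio manager before relying on a new key.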
Can I run multiple risk agents in parallel instead of sequentially?
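Yes. The graph supports parallel execution. Instead of chaining with workflow.add_edge("risk_management_agent", "sector_cap_risk"), add one edge from the same upstream node to each risk agent, for example workflow.add_edge("analyst_1", "risk_management_agent") and workflow.add_edge("analyst_1", "sector_cap_risk"); in LangGraph the second argument to add_edge is a single node name, not a list. You must then implement a reduction step that merges conflicting limits before the portfolio manager node. Note that the sector cap agent from Method 2 reads the default risk agent's output, so it would have to compute its limits independently before it could run in parallel.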
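One possible reduction step is to keep the most restrictive limit per ticker. The sketch below assumes each risk agent produces a flat ticker-to-limit mapping; merge_limits and the sample numbers are hypothetical, and taking the minimum is a design choice rather than something the repository prescribes.

```python
from typing import Dict, Iterable


def merge_limits(limit_sets: Iterable[Dict[str, float]]) -> Dict[str, float]:
    """Keep the smallest (most restrictive) limit per ticker across all risk agents."""
    merged: Dict[str, float] = {}
    for limits in limit_sets:
        for ticker, limit in limits.items():
            merged[ticker] = min(limit, merged.get(ticker, float("inf")))
    return merged


volatility_limits = {"AAPL": 12_000.0, "TSLA": 5_000.0}
sector_limits = {"AAPL": 6_000.0, "TSLA": 7_500.0}
merged = merge_limits([volatility_limits, sector_limits])
print(merged)
```

A merge node built around a function like this would run after both risk agents finish and write the merged limits under the key the portfolio manager reads.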
Where should I store sector mappings or external data for risk calculations?
Store static mappings (like ticker-to-sector tables) in a JSON or CSV file within your agent directory (e.g., src/agents/data/sector_map.json) and load them once when the agent module is imported. For dynamic data, extend the AgentState definition in src/graph/state.py to include a new field, or fetch the data inside the agent; price history is available through the existing get_prices utility in src/tools/api.py.
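Loading a static mapping could look like the following sketch. The load_sector_map helper and the upper-casing of tickers are assumptions for illustration; the demo writes to a temporary file so it runs anywhere, whereas in the project the file would sit at a path such as src/agents/data/sector_map.json.

```python
import json
import tempfile
from pathlib import Path
from typing import Dict


def load_sector_map(path: Path) -> Dict[str, str]:
    """Read a ticker-to-sector mapping from a JSON file, normalizing tickers to upper case."""
    with path.open(encoding="utf-8") as handle:
        raw = json.load(handle)
    return {str(ticker).upper(): str(sector) for ticker, sector in raw.items()}


# Demo with a temporary file standing in for the real mapping file.
with tempfile.TemporaryDirectory() as tmp:
    demo_path = Path(tmp) / "sector_map.json"
    demo_path.write_text(json.dumps({"aapl": "Tech", "TSLA": "Auto"}), encoding="utf-8")
    sector_map = load_sector_map(demo_path)

# Unknown tickers fall back to a catch-all bucket, as in the sector cap agent.
print(sector_map.get("AAPL", "Other"), sector_map.get("NVDA", "Other"))
```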
How do I test custom risk rules without executing live trades?
If your version of the CLI offers a dry-run or simulation mode, you can use it; otherwise, isolate your custom agent in a unit test. Construct a mock state dictionary that follows the AgentState structure from src/graph/state.py, fill it with sample portfolio data, and pass it to your risk agent function to verify the output limits. The show_agent_reasoning utility in src/graph/state.py will print the decision trace for manual verification.
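A unit test along these lines might look like the sketch below. To stay self-contained it tests flat_cap_agent, a hypothetical stand-in that caps each limit at a percentage of cash; in practice you would import your own agent instead, and the mock state fields shown are assumptions based on the keys used in this article.

```python
import copy
from typing import Any, Dict


def flat_cap_agent(state: Dict[str, Any], cap_pct: float = 0.05) -> Dict[str, Any]:
    """Hypothetical stand-in for a custom risk agent: caps each limit at cap_pct of cash."""
    cap = state["data"]["portfolio"]["cash"] * cap_pct
    signals = state["data"]["analyst_signals"]["risk_management_agent"]
    for analysis in signals.values():
        analysis["remaining_position_limit"] = min(analysis["remaining_position_limit"], cap)
    return {"messages": state["messages"], "data": state["data"]}


MOCK_STATE: Dict[str, Any] = {
    "messages": [],
    "metadata": {"show_reasoning": False},
    "data": {
        "portfolio": {"cash": 100_000.0, "positions": {}},
        "analyst_signals": {
            "risk_management_agent": {
                "AAPL": {"remaining_position_limit": 20_000.0, "reasoning": {}},
                "MSFT": {"remaining_position_limit": 3_000.0, "reasoning": {}},
            }
        },
    },
}


def test_flat_cap_agent() -> None:
    state = copy.deepcopy(MOCK_STATE)  # keep the fixture unchanged between tests
    result = flat_cap_agent(state)
    limits = result["data"]["analyst_signals"]["risk_management_agent"]
    assert abs(limits["AAPL"]["remaining_position_limit"] - 5_000.0) < 1e-6  # capped
    assert limits["MSFT"]["remaining_position_limit"] == 3_000.0  # already under the cap


test_flat_cap_agent()
```

Deep-copying the fixture matters because the agents in this article mutate the state in place.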