How to Add Support for Additional LLM Providers in AI Hedge Fund
To add support for additional LLM providers in the AI Hedge Fund framework, extend the ModelProvider enum in src/llm/models.py, import the appropriate LangChain chat wrapper, and implement a new conditional branch in the get_model factory function that handles API key validation and client instantiation.
The virattt/ai-hedge-fund repository abstracts LLM integrations behind a unified factory pattern, allowing you to add support for additional LLM providers without modifying the high-level inference logic. By updating the provider enum and model factory in src/llm/models.py, you can integrate any LangChain-compatible chat model while maintaining the existing call_llm interface used throughout the application.
Step-by-Step Implementation Guide
This guide walks through integrating a new provider (using Cohere as an example) into the codebase.
Step 1: Extend the ModelProvider Enum
Open src/llm/models.py and add a new member to the ModelProvider enumeration (around line 18). This string-based enum identifies the provider throughout the application.
```python
# src/llm/models.py
from enum import Enum


class ModelProvider(str, Enum):
    """Enum for supported LLM providers"""
    OPENAI = "OpenAI"
    ANTHROPIC = "Anthropic"
    GROQ = "Groq"
    XAI = "xAI"
    # ... other existing members omitted ...
    # Add your new provider below
    COHERE = "Cohere"
```
Reference: ModelProvider enum definition
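Because ModelProvider mixes in str, each member compares equal to its plain string value, and calling the enum with that string looks the member up by value. This is why provider strings stored elsewhere (for example in the JSON model lists) must match the enum values exactly. The sketch below re-declares a trimmed, standalone copy of the enum so it runs on its own; it is not the repository's full definition.

```python
from enum import Enum


class ModelProvider(str, Enum):
    """Trimmed, standalone copy of the enum, for illustration only."""
    OPENAI = "OpenAI"
    ANTHROPIC = "Anthropic"
    COHERE = "Cohere"


# A str-mixin enum member compares equal to its plain string value...
assert ModelProvider.COHERE == "Cohere"

# ...and calling the enum with a value string performs a lookup by value.
provider = ModelProvider("Cohere")
print(provider is ModelProvider.COHERE)

# A string that matches no value (e.g. a casing typo in a JSON file) raises ValueError.
try:
    ModelProvider("cohere")
except ValueError as exc:
    print(f"Lookup failed: {exc}")
```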
Step 2: Install and Import the LangChain Wrapper
Install the provider's LangChain integration package. For example, to add Cohere:
```shell
pip install langchain-cohere
```
Then import the chat wrapper at the top of src/llm/models.py (around lines 3-12):
```python
# src/llm/models.py
from langchain_cohere import ChatCohere  # New import
```
Reference: Existing imports in models.py
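If the new import fails at startup, it is usually because the integration package is missing from the active environment. A quick check is sketched below; `is_package_installed` is a hypothetical helper (not part of the repository) built on the standard library's `importlib.util.find_spec`, which returns None for a missing top-level module without importing it. Note that the pip package name (`langchain-cohere`) differs from the module name (`langchain_cohere`).

```python
import importlib.util


def is_package_installed(module_name: str) -> bool:
    """Hypothetical helper: True if `module_name` is importable here.

    importlib.util.find_spec returns None when a top-level module cannot be
    found, so the check does not actually import the package.
    """
    return importlib.util.find_spec(module_name) is not None


if __name__ == "__main__":
    # pip installs "langchain-cohere" as the importable module "langchain_cohere".
    if is_package_installed("langchain_cohere"):
        print("langchain_cohere is available")
    else:
        print("Run: pip install langchain-cohere")
```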
Step 3: Add the Provider Case to get_model
Locate the get_model factory function (starting around line 138) and add a new elif block that instantiates the provider's client. The function receives model_name, model_provider, and an optional api_keys dictionary.
```python
# src/llm/models.py
# (os and ChatCohere are imported at the top of the module.)
def get_model(model_name: str, model_provider: ModelProvider, api_keys: dict = None):
    ...
    elif model_provider == ModelProvider.COHERE:
        # Prefer a key passed in via api_keys; fall back to the environment.
        api_key = (api_keys or {}).get("COHERE_API_KEY") or os.getenv("COHERE_API_KEY")
        if not api_key:
            print("API Key Error: Please set COHERE_API_KEY in your .env or provide via API keys.")
            raise ValueError("Cohere API key not found.")
        # ChatCohere receives the key through its cohere_api_key field.
        return ChatCohere(model=model_name, cohere_api_key=api_key)
    ...
```
Reference: get_model factory implementation
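Every provider branch repeats the same lookup order, so it can help to see that order in isolation. The sketch below uses a hypothetical `resolve_api_key` helper (not in the repository) that mirrors the branch above: a non-empty value in `api_keys` wins, the environment variable is the fallback, and a missing key raises a descriptive ValueError.

```python
import os
from typing import Mapping, Optional


def resolve_api_key(api_keys: Optional[Mapping[str, str]], env_var: str) -> str:
    """Hypothetical helper mirroring the key lookup used in get_model.

    1. A non-empty value in the optional api_keys mapping wins.
    2. Otherwise fall back to the environment variable of the same name.
    3. If neither is set, fail loudly with a descriptive ValueError.
    """
    api_key = (api_keys or {}).get(env_var) or os.getenv(env_var)
    if not api_key:
        raise ValueError(f"{env_var} not found in api_keys or the environment.")
    return api_key
```

Inside the Cohere branch this would read `api_key = resolve_api_key(api_keys, "COHERE_API_KEY")`. Because of the `or` chain, an empty string in `api_keys` is treated as missing and the environment is consulted instead.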
Step 4: Register Models in JSON (Optional)
To make the new provider's models selectable in the UI, add entries to src/llm/api_models.json (or ollama_models.json for local models). In each entry, the provider field must match the enum member's value string exactly (here "Cohere", not "COHERE").
```json
{
  "display_name": "Cohere Command R+",
  "model_name": "command-r-plus",
  "provider": "Cohere"
}
```
Reference: api_models.json structure
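A typo in the provider string is easy to miss until the app fails to load a model, so a small pre-flight check can be useful. This is a sketch under two assumptions: the JSON file holds an array of objects shaped like the entry above, and the trimmed `ModelProvider` stand-in is replaced by an import of the real enum. `find_invalid_entries` is a hypothetical name.

```python
import json
from enum import Enum
from typing import List


class ModelProvider(str, Enum):
    # Trimmed stand-in for the enum in src/llm/models.py.
    OPENAI = "OpenAI"
    COHERE = "Cohere"


REQUIRED_KEYS = {"display_name", "model_name", "provider"}


def find_invalid_entries(raw_json: str) -> List[str]:
    """Hypothetical validator: list problems found in a models JSON document."""
    valid_providers = {p.value for p in ModelProvider}
    problems: List[str] = []
    for index, entry in enumerate(json.loads(raw_json)):
        missing = REQUIRED_KEYS - set(entry)
        if missing:
            problems.append(f"entry {index}: missing keys {sorted(missing)}")
        elif entry["provider"] not in valid_providers:
            # Lookup is by enum *value*, so "cohere" or "COHERE" would fail here.
            problems.append(f"entry {index}: unknown provider {entry['provider']!r}")
    return problems
```

Usage: `with open("src/llm/api_models.json") as f: print(find_invalid_entries(f.read()))` prints an empty list when every entry is well-formed.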
Step 5: Verify UI Integration
The application builds its model selection dropdowns from the LLM_ORDER constant (defined around line 108 in src/llm/models.py), which is derived from AVAILABLE_MODELS, the list populated from the JSON files. No UI code changes are required; restart the application and the new provider appears in the dropdown menus.
```python
# src/llm/models.py
LLM_ORDER = [model.to_choice_tuple() for model in AVAILABLE_MODELS]
```
Reference: LLM_ORDER construction
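The path from JSON entry to dropdown option can be sketched end to end. The `LLMModel` dataclass below is a hypothetical stand-in for the repository's model metadata class, and the assumed tuple shape (display name, model name, provider string) is an illustration; the real class may store the provider as a `ModelProvider` member and return its value.

```python
import json
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class LLMModel:
    # Hypothetical stand-in for the repository's model metadata class.
    display_name: str
    model_name: str
    provider: str

    def to_choice_tuple(self) -> Tuple[str, str, str]:
        # Assumed shape: (label shown in the UI, model id, provider string).
        return (self.display_name, self.model_name, self.provider)


def load_models(raw_json: str) -> List[LLMModel]:
    # Each JSON object becomes one LLMModel; unknown keys would raise TypeError.
    return [LLMModel(**entry) for entry in json.loads(raw_json)]


RAW = '''[
  {"display_name": "Cohere Command R+", "model_name": "command-r-plus", "provider": "Cohere"}
]'''

AVAILABLE_MODELS = load_models(RAW)
LLM_ORDER = [model.to_choice_tuple() for model in AVAILABLE_MODELS]
print(LLM_ORDER)
```

Since the dropdown is just a projection of this list, adding a JSON entry is enough to surface the model; nothing in the UI layer enumerates providers by hand.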
Step 6: Test the Integration
The high-level call_llm function in src/utils/llm.py automatically routes to your new provider via get_model. Verify end-to-end functionality by invoking the model:
```python
# src/utils/llm.py usage pattern (inside call_llm)
model_info = get_model_info(model_name, model_provider)
llm = get_model(model_name, model_provider, api_keys)  # Now supports your new provider
```
Reference: call_llm implementation
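Before running a full analysis, a one-prompt smoke test can confirm that the factory returns a working client. This is a sketch: `smoke_test_provider` is a hypothetical helper, and the factory is passed in as a parameter so the same check can run against the real `get_model` or against a stub in unit tests. It relies only on the standard LangChain behaviour that a chat model's `invoke()` returns a message with a `.content` attribute.

```python
from typing import Any, Callable, Mapping, Optional


def smoke_test_provider(
    get_model: Callable[..., Any],
    model_name: str,
    model_provider: Any,
    api_keys: Optional[Mapping[str, str]] = None,
) -> str:
    """Hypothetical smoke test: build a client via the factory and send one prompt."""
    llm = get_model(model_name, model_provider, api_keys)
    # LangChain chat models return a message object whose text is in .content.
    reply = llm.invoke("Reply with the single word: pong")
    if not reply.content:
        raise RuntimeError(f"{model_provider} returned an empty response")
    return reply.content
```

Run from the repository root (an assumption about how the `src` package resolves), usage might look like `from src.llm.models import get_model, ModelProvider` followed by `print(smoke_test_provider(get_model, "command-r-plus", ModelProvider.COHERE))`. This makes a real, billable API call.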
Complete Implementation Example
Below is a unified diff showing the complete changes required to add Cohere support:
```diff
--- a/src/llm/models.py
+++ b/src/llm/models.py
@@ -1,6 +1,7 @@
 from langchain_openai import ChatOpenAI, AzureChatOpenAI
 from langchain_gigachat import GigaChat
 from langchain_ollama import ChatOllama
+from langchain_cohere import ChatCohere
 class ModelProvider(str, Enum):
     """Enum for supported LLM providers"""
@@ -9,6 +10,7 @@
     GROQ = "Groq"
     XAI = "xAI"
     GIGACHAT = "GigaChat"
+    COHERE = "Cohere"
 def get_model(model_name: str, model_provider: ModelProvider, api_keys: dict = None):
     ...
@@ -20,6 +22,12 @@
         if not api_key:
             raise ValueError("xAI API key not found.")
         return ChatXAI(model=model_name, api_key=api_key)
+    elif model_provider == ModelProvider.COHERE:
+        api_key = (api_keys or {}).get("COHERE_API_KEY") or os.getenv("COHERE_API_KEY")
+        if not api_key:
+            print("API Key Error: Please set COHERE_API_KEY in your .env or provide via API keys.")
+            raise ValueError("Cohere API key not found.")
+        return ChatCohere(model=model_name, cohere_api_key=api_key)
     ...
```
Add the corresponding entry to src/llm/api_models.json:
```json
{
  "display_name": "Cohere Command R+",
  "model_name": "command-r-plus",
  "provider": "Cohere"
}
```
Summary
- Extend the enum: Add the provider identifier to ModelProvider in src/llm/models.py.
- Import the wrapper: Install and import the LangChain chat class for your provider.
- Implement the factory: Add an elif branch in get_model that handles API key retrieval and client instantiation.
- Register models: Add model metadata to src/llm/api_models.json to populate the UI dropdowns.
- Verify integration: The call_llm function in src/utils/llm.py automatically uses the new provider without additional changes.
Frequently Asked Questions
Do I need to modify the UI code to add a new LLM provider?
No. The UI generates provider and model dropdowns automatically from the LLM_ORDER constant in src/llm/models.py, which is constructed from AVAILABLE_MODELS. As long as you register your models in src/llm/api_models.json and extend the ModelProvider enum, the new options appear after restarting the application.
What if my LLM provider doesn't have a LangChain integration?
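Then you need to write a custom chat model class that subclasses LangChain's BaseChatModel; at minimum this means implementing the _generate method and the _llm_type property. The get_model factory is expected to return an object that supports LangChain's standard methods such as invoke() and stream(), and BaseChatModel supplies these once _generate is in place. Without this compatibility layer, the provider cannot plug into the existing call_llm abstraction.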
How do I handle API key validation for new providers?
Follow the established pattern in src/llm/models.py: retrieve the key from the optional api_keys dictionary parameter first, then fall back to environment variables using os.getenv(). Always validate the key exists before instantiating the client, and raise a descriptive ValueError if missing to match the framework's error handling conventions.
Where should I store model metadata for local Ollama models?
Register local Ollama models in src/llm/ollama_models.json rather than api_models.json. The framework maintains separate JSON files for API-based providers and local Ollama instances, though both use the same ModelProvider.OLLAMA enum value and ChatOllama wrapper in the factory function.