Skip to content
Snippets Groups Projects
Unverified Commit f21b6e7d authored by Connor's avatar Connor Committed by GitHub
Browse files

:rock: feat: AWS Bedrock embeddings support (#75)

* "WIP: adding bedrock embeddings"

* WIP: feat bedrock embeddings support

* feat: aws bedrock embeddings support

* refactor: update aws region var name

* docs: update env variables documentation for bedrock

* docs: add bedrock embeddings provider in list
parent edd8a0c0
No related branches found
No related tags found
No related merge requests found
...@@ -61,7 +61,7 @@ The following environment variables are required to run the application: ...@@ -61,7 +61,7 @@ The following environment variables are required to run the application:
- `PDF_EXTRACT_IMAGES`: (Optional) A boolean value indicating whether to extract images from PDF files. Default value is "False". - `PDF_EXTRACT_IMAGES`: (Optional) A boolean value indicating whether to extract images from PDF files. Default value is "False".
- `DEBUG_RAG_API`: (Optional) Set to "True" to show more verbose logging output in the server console, and to enable postgresql database routes - `DEBUG_RAG_API`: (Optional) Set to "True" to show more verbose logging output in the server console, and to enable postgresql database routes
- `CONSOLE_JSON`: (Optional) Set to "True" to log as json for Cloud Logging aggregations - `CONSOLE_JSON`: (Optional) Set to "True" to log as json for Cloud Logging aggregations
- `EMBEDDINGS_PROVIDER`: (Optional) either "openai", "azure", "huggingface", "huggingfacetei" or "ollama", where "huggingface" uses sentence_transformers; defaults to "openai" - `EMBEDDINGS_PROVIDER`: (Optional) either "openai", "bedrock", "azure", "huggingface", "huggingfacetei" or "ollama", where "huggingface" uses sentence_transformers; defaults to "openai"
- `EMBEDDINGS_MODEL`: (Optional) Set a valid embeddings model to use from the configured provider. - `EMBEDDINGS_MODEL`: (Optional) Set a valid embeddings model to use from the configured provider.
- **Defaults** - **Defaults**
- openai: "text-embedding-3-small" - openai: "text-embedding-3-small"
...@@ -69,6 +69,7 @@ The following environment variables are required to run the application: ...@@ -69,6 +69,7 @@ The following environment variables are required to run the application:
- huggingface: "sentence-transformers/all-MiniLM-L6-v2" - huggingface: "sentence-transformers/all-MiniLM-L6-v2"
- huggingfacetei: "http://huggingfacetei:3000". Hugging Face TEI uses model defined on TEI service launch. - huggingfacetei: "http://huggingfacetei:3000". Hugging Face TEI uses model defined on TEI service launch.
- ollama: "nomic-embed-text" - ollama: "nomic-embed-text"
- bedrock: "amazon.titan-embed-text-v1"
- `RAG_AZURE_OPENAI_API_VERSION`: (Optional) Default is `2023-05-15`. The version of the Azure OpenAI API. - `RAG_AZURE_OPENAI_API_VERSION`: (Optional) Default is `2023-05-15`. The version of the Azure OpenAI API.
- `RAG_AZURE_OPENAI_API_KEY`: (Optional) The API key for Azure OpenAI service. - `RAG_AZURE_OPENAI_API_KEY`: (Optional) The API key for Azure OpenAI service.
- Note: `AZURE_OPENAI_API_KEY` will work but `RAG_AZURE_OPENAI_API_KEY` will override it in order to not conflict with LibreChat setting. - Note: `AZURE_OPENAI_API_KEY` will work but `RAG_AZURE_OPENAI_API_KEY` will override it in order to not conflict with LibreChat setting.
...@@ -79,6 +80,9 @@ The following environment variables are required to run the application: ...@@ -79,6 +80,9 @@ The following environment variables are required to run the application:
- `OLLAMA_BASE_URL`: (Optional) defaults to `http://ollama:11434`. - `OLLAMA_BASE_URL`: (Optional) defaults to `http://ollama:11434`.
- `ATLAS_SEARCH_INDEX`: (Optional) the name of the vector search index if using Atlas MongoDB, defaults to `vector_index` - `ATLAS_SEARCH_INDEX`: (Optional) the name of the vector search index if using Atlas MongoDB, defaults to `vector_index`
- `MONGO_VECTOR_COLLECTION`: Deprecated for MongoDB, please use `ATLAS_SEARCH_INDEX` and `COLLECTION_NAME` - `MONGO_VECTOR_COLLECTION`: Deprecated for MongoDB, please use `ATLAS_SEARCH_INDEX` and `COLLECTION_NAME`
- `AWS_DEFAULT_REGION`: (Optional) defaults to `us-east-1`
- `AWS_ACCESS_KEY_ID`: (Optional) needed for bedrock embeddings
- `AWS_SECRET_ACCESS_KEY`: (Optional) needed for bedrock embeddings
Make sure to set these environment variables before running the application. You can set them in a `.env` file or as system environment variables. Make sure to set these environment variables before running the application. You can set them in a `.env` file or as system environment variables.
......
...@@ -2,11 +2,13 @@ ...@@ -2,11 +2,13 @@
import os import os
import json import json
import logging import logging
import boto3
from enum import Enum from enum import Enum
from datetime import datetime from datetime import datetime
from dotenv import find_dotenv, load_dotenv from dotenv import find_dotenv, load_dotenv
from langchain_ollama import OllamaEmbeddings from langchain_ollama import OllamaEmbeddings
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpointEmbeddings from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpointEmbeddings
from langchain_aws import BedrockEmbeddings
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.base import BaseHTTPMiddleware
from store_factory import get_vector_store from store_factory import get_vector_store
...@@ -25,6 +27,7 @@ class EmbeddingsProvider(Enum): ...@@ -25,6 +27,7 @@ class EmbeddingsProvider(Enum):
HUGGINGFACE = "huggingface" HUGGINGFACE = "huggingface"
HUGGINGFACETEI = "huggingfacetei" HUGGINGFACETEI = "huggingfacetei"
OLLAMA = "ollama" OLLAMA = "ollama"
BEDROCK = "bedrock"
def get_env_variable( def get_env_variable(
...@@ -168,6 +171,8 @@ RAG_AZURE_OPENAI_ENDPOINT = get_env_variable( ...@@ -168,6 +171,8 @@ RAG_AZURE_OPENAI_ENDPOINT = get_env_variable(
).rstrip("/") ).rstrip("/")
HF_TOKEN = get_env_variable("HF_TOKEN", "") HF_TOKEN = get_env_variable("HF_TOKEN", "")
OLLAMA_BASE_URL = get_env_variable("OLLAMA_BASE_URL", "http://ollama:11434") OLLAMA_BASE_URL = get_env_variable("OLLAMA_BASE_URL", "http://ollama:11434")
AWS_ACCESS_KEY_ID = get_env_variable("AWS_ACCESS_KEY_ID", "")
AWS_SECRET_ACCESS_KEY = get_env_variable("AWS_SECRET_ACCESS_KEY", "")
## Embeddings ## Embeddings
...@@ -195,6 +200,17 @@ def init_embeddings(provider, model): ...@@ -195,6 +200,17 @@ def init_embeddings(provider, model):
return HuggingFaceEndpointEmbeddings(model=model) return HuggingFaceEndpointEmbeddings(model=model)
elif provider == EmbeddingsProvider.OLLAMA: elif provider == EmbeddingsProvider.OLLAMA:
return OllamaEmbeddings(model=model, base_url=OLLAMA_BASE_URL) return OllamaEmbeddings(model=model, base_url=OLLAMA_BASE_URL)
elif provider == EmbeddingsProvider.BEDROCK:
session = boto3.Session(
aws_access_key_id=AWS_ACCESS_KEY_ID,
aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
region_name=AWS_DEFAULT_REGION,
)
return BedrockEmbeddings(
client=session.client("bedrock-runtime"),
model_id=model,
region_name=AWS_DEFAULT_REGION,
)
else: else:
raise ValueError(f"Unsupported embeddings provider: {provider}") raise ValueError(f"Unsupported embeddings provider: {provider}")
...@@ -217,6 +233,13 @@ elif EMBEDDINGS_PROVIDER == EmbeddingsProvider.HUGGINGFACETEI: ...@@ -217,6 +233,13 @@ elif EMBEDDINGS_PROVIDER == EmbeddingsProvider.HUGGINGFACETEI:
) )
elif EMBEDDINGS_PROVIDER == EmbeddingsProvider.OLLAMA: elif EMBEDDINGS_PROVIDER == EmbeddingsProvider.OLLAMA:
EMBEDDINGS_MODEL = get_env_variable("EMBEDDINGS_MODEL", "nomic-embed-text") EMBEDDINGS_MODEL = get_env_variable("EMBEDDINGS_MODEL", "nomic-embed-text")
elif EMBEDDINGS_PROVIDER == EmbeddingsProvider.BEDROCK:
EMBEDDINGS_MODEL = get_env_variable(
"EMBEDDINGS_MODEL", "amazon.titan-embed-text-v1"
)
AWS_DEFAULT_REGION = get_env_variable(
"AWS_DEFAULT_REGION", "us-east-1"
)
else: else:
raise ValueError(f"Unsupported embeddings provider: {EMBEDDINGS_PROVIDER}") raise ValueError(f"Unsupported embeddings provider: {EMBEDDINGS_PROVIDER}")
......
...@@ -2,6 +2,8 @@ langchain==0.3 ...@@ -2,6 +2,8 @@ langchain==0.3
langchain_community==0.3 langchain_community==0.3
langchain_openai==0.2.0 langchain_openai==0.2.0
langchain_core==0.3.5 langchain_core==0.3.5
langchain-aws==0.2.1
boto3==1.34.144
sqlalchemy==2.0.28 sqlalchemy==2.0.28
python-dotenv==1.0.1 python-dotenv==1.0.1
fastapi==0.110.0 fastapi==0.110.0
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment