From f21b6e7d25e4f87269129ee12c4e3fb662946326 Mon Sep 17 00:00:00 2001
From: Connor <36115510+ScarFX@users.noreply.github.com>
Date: Fri, 27 Sep 2024 13:34:24 -0400
Subject: [PATCH] =?UTF-8?q?=F0=9F=AA=A8=20feat:=20AWS=20Bedrock=20embeddin?=
=?UTF-8?q?gs=20support=20=20(#75)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
* "WIP: adding bedrock embeddings"
* WIP: feat bedrock embeddings support
* feat: aws bedrock embeddings support
* refactor: update aws region var name
* docs: update env variables documentation for bedrock
* docs: add bedrock embeddings provider in list
---
README.md | 6 +++++-
config.py | 23 +++++++++++++++++++++++
requirements.txt | 2 ++
3 files changed, 30 insertions(+), 1 deletion(-)
diff --git a/README.md b/README.md
index 3146fe3..37bdcdf 100644
--- a/README.md
+++ b/README.md
@@ -61,7 +61,7 @@ The following environment variables are required to run the application:
- `PDF_EXTRACT_IMAGES`: (Optional) A boolean value indicating whether to extract images from PDF files. Default value is "False".
- `DEBUG_RAG_API`: (Optional) Set to "True" to show more verbose logging output in the server console, and to enable postgresql database routes
- `CONSOLE_JSON`: (Optional) Set to "True" to log as json for Cloud Logging aggregations
-- `EMBEDDINGS_PROVIDER`: (Optional) either "openai", "azure", "huggingface", "huggingfacetei" or "ollama", where "huggingface" uses sentence_transformers; defaults to "openai"
+- `EMBEDDINGS_PROVIDER`: (Optional) either "openai", "bedrock", "azure", "huggingface", "huggingfacetei" or "ollama", where "huggingface" uses sentence_transformers; defaults to "openai"
- `EMBEDDINGS_MODEL`: (Optional) Set a valid embeddings model to use from the configured provider.
- **Defaults**
- openai: "text-embedding-3-small"
@@ -69,6 +69,7 @@ The following environment variables are required to run the application:
- huggingface: "sentence-transformers/all-MiniLM-L6-v2"
- huggingfacetei: "http://huggingfacetei:3000". Hugging Face TEI uses model defined on TEI service launch.
- ollama: "nomic-embed-text"
+ - bedrock: "amazon.titan-embed-text-v1"
- `RAG_AZURE_OPENAI_API_VERSION`: (Optional) Default is `2023-05-15`. The version of the Azure OpenAI API.
- `RAG_AZURE_OPENAI_API_KEY`: (Optional) The API key for Azure OpenAI service.
- Note: `AZURE_OPENAI_API_KEY` will work but `RAG_AZURE_OPENAI_API_KEY` will override it in order to not conflict with LibreChat setting.
@@ -79,6 +80,9 @@ The following environment variables are required to run the application:
- `OLLAMA_BASE_URL`: (Optional) defaults to `http://ollama:11434`.
- `ATLAS_SEARCH_INDEX`: (Optional) the name of the vector search index if using Atlas MongoDB, defaults to `vector_index`
- `MONGO_VECTOR_COLLECTION`: Deprecated for MongoDB, please use `ATLAS_SEARCH_INDEX` and `COLLECTION_NAME`
+- `AWS_DEFAULT_REGION`: (Optional) the AWS region used for Bedrock embeddings; defaults to `us-east-1`
+- `AWS_ACCESS_KEY_ID`: (Optional) required for Bedrock embeddings
+- `AWS_SECRET_ACCESS_KEY`: (Optional) required for Bedrock embeddings
Make sure to set these environment variables before running the application. You can set them in a `.env` file or as system environment variables.
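For reference, a minimal `.env` fragment enabling Bedrock embeddings might look like the following; the values shown are illustrative placeholders, not required settings:

    EMBEDDINGS_PROVIDER=bedrock
    EMBEDDINGS_MODEL=amazon.titan-embed-text-v1
    AWS_DEFAULT_REGION=us-east-1
    AWS_ACCESS_KEY_ID=<your-access-key-id>
    AWS_SECRET_ACCESS_KEY=<your-secret-access-key>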
diff --git a/config.py b/config.py
index 36f9202..d8d65ba 100644
--- a/config.py
+++ b/config.py
@@ -2,11 +2,13 @@
import os
import json
import logging
+import boto3
from enum import Enum
from datetime import datetime
from dotenv import find_dotenv, load_dotenv
from langchain_ollama import OllamaEmbeddings
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpointEmbeddings
+from langchain_aws import BedrockEmbeddings
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
from starlette.middleware.base import BaseHTTPMiddleware
from store_factory import get_vector_store
@@ -25,6 +27,7 @@ class EmbeddingsProvider(Enum):
HUGGINGFACE = "huggingface"
HUGGINGFACETEI = "huggingfacetei"
OLLAMA = "ollama"
+ BEDROCK = "bedrock"
def get_env_variable(
@@ -168,6 +171,8 @@ RAG_AZURE_OPENAI_ENDPOINT = get_env_variable(
).rstrip("/")
HF_TOKEN = get_env_variable("HF_TOKEN", "")
OLLAMA_BASE_URL = get_env_variable("OLLAMA_BASE_URL", "http://ollama:11434")
+AWS_ACCESS_KEY_ID = get_env_variable("AWS_ACCESS_KEY_ID", "")
+AWS_SECRET_ACCESS_KEY = get_env_variable("AWS_SECRET_ACCESS_KEY", "")
## Embeddings
@@ -195,6 +200,17 @@ def init_embeddings(provider, model):
return HuggingFaceEndpointEmbeddings(model=model)
elif provider == EmbeddingsProvider.OLLAMA:
return OllamaEmbeddings(model=model, base_url=OLLAMA_BASE_URL)
+ elif provider == EmbeddingsProvider.BEDROCK:
+ session = boto3.Session(
+ aws_access_key_id=AWS_ACCESS_KEY_ID,
+ aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
+ region_name=AWS_DEFAULT_REGION,
+ )
+ return BedrockEmbeddings(
+ client=session.client("bedrock-runtime"),
+ model_id=model,
+ region_name=AWS_DEFAULT_REGION,
+ )
else:
raise ValueError(f"Unsupported embeddings provider: {provider}")
@@ -217,6 +233,13 @@ elif EMBEDDINGS_PROVIDER == EmbeddingsProvider.HUGGINGFACETEI:
)
elif EMBEDDINGS_PROVIDER == EmbeddingsProvider.OLLAMA:
EMBEDDINGS_MODEL = get_env_variable("EMBEDDINGS_MODEL", "nomic-embed-text")
+elif EMBEDDINGS_PROVIDER == EmbeddingsProvider.BEDROCK:
+ EMBEDDINGS_MODEL = get_env_variable(
+ "EMBEDDINGS_MODEL", "amazon.titan-embed-text-v1"
+ )
+ AWS_DEFAULT_REGION = get_env_variable(
+ "AWS_DEFAULT_REGION", "us-east-1"
+ )
else:
raise ValueError(f"Unsupported embeddings provider: {EMBEDDINGS_PROVIDER}")
diff --git a/requirements.txt b/requirements.txt
index f015b94..28f18a0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,6 +2,8 @@ langchain==0.3
langchain_community==0.3
langchain_openai==0.2.0
langchain_core==0.3.5
+langchain-aws==0.2.1
+boto3==1.34.144
sqlalchemy==2.0.28
python-dotenv==1.0.1
fastapi==0.110.0
--
GitLab