From f21b6e7d25e4f87269129ee12c4e3fb662946326 Mon Sep 17 00:00:00 2001
From: Connor <36115510+ScarFX@users.noreply.github.com>
Date: Fri, 27 Sep 2024 13:34:24 -0400
Subject: [PATCH] =?UTF-8?q?=F0=9F=AA=A8=20feat:=20AWS=20Bedrock=20embeddin?=
 =?UTF-8?q?gs=20support=20=20(#75)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* WIP: adding bedrock embeddings

* WIP: feat bedrock embeddings support

* feat: aws bedrock embeddings support

* refactor: update aws region var name

* docs: update env variables documentation for bedrock

* docs: add bedrock embeddings provider in list
---
 README.md        |  6 +++++-
 config.py        | 23 +++++++++++++++++++++++
 requirements.txt |  2 ++
 3 files changed, 30 insertions(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 3146fe3..37bdcdf 100644
--- a/README.md
+++ b/README.md
@@ -61,7 +61,7 @@ The following environment variables are required to run the application:
 - `PDF_EXTRACT_IMAGES`: (Optional) A boolean value indicating whether to extract images from PDF files. Default value is "False".
 - `DEBUG_RAG_API`: (Optional) Set to "True" to show more verbose logging output in the server console, and to enable postgresql database routes
 - `CONSOLE_JSON`: (Optional) Set to "True" to log as json for Cloud Logging aggregations
-- `EMBEDDINGS_PROVIDER`: (Optional) either "openai", "azure", "huggingface", "huggingfacetei" or "ollama", where "huggingface" uses sentence_transformers; defaults to "openai"
+- `EMBEDDINGS_PROVIDER`: (Optional) either "openai", "bedrock", "azure", "huggingface", "huggingfacetei" or "ollama", where "huggingface" uses sentence_transformers; defaults to "openai"
 - `EMBEDDINGS_MODEL`: (Optional) Set a valid embeddings model to use from the configured provider.
     - **Defaults**
     - openai: "text-embedding-3-small"
@@ -69,6 +69,7 @@ The following environment variables are required to run the application:
     - huggingface: "sentence-transformers/all-MiniLM-L6-v2"
     - huggingfacetei: "http://huggingfacetei:3000". Hugging Face TEI uses model defined on TEI service launch.
     - ollama: "nomic-embed-text"
+    - bedrock: "amazon.titan-embed-text-v1"
 - `RAG_AZURE_OPENAI_API_VERSION`: (Optional) Default is `2023-05-15`. The version of the Azure OpenAI API.
 - `RAG_AZURE_OPENAI_API_KEY`: (Optional) The API key for Azure OpenAI service.
     - Note: `AZURE_OPENAI_API_KEY` will work but `RAG_AZURE_OPENAI_API_KEY` will override it in order to not conflict with LibreChat setting.
@@ -79,6 +80,9 @@ The following environment variables are required to run the application:
 - `OLLAMA_BASE_URL`: (Optional) defaults to `http://ollama:11434`.
 - `ATLAS_SEARCH_INDEX`: (Optional) the name of the vector search index if using Atlas MongoDB, defaults to `vector_index`
 - `MONGO_VECTOR_COLLECTION`: Deprecated for MongoDB, please use `ATLAS_SEARCH_INDEX` and `COLLECTION_NAME`
+- `AWS_DEFAULT_REGION`: (Optional) defaults to `us-east-1`
+- `AWS_ACCESS_KEY_ID`: (Optional) needed for bedrock embeddings
+- `AWS_SECRET_ACCESS_KEY`: (Optional) needed for bedrock embeddings
 
 Make sure to set these environment variables before running the application. You can set them in a `.env` file or as system environment variables.
 
diff --git a/config.py b/config.py
index 36f9202..d8d65ba 100644
--- a/config.py
+++ b/config.py
@@ -2,11 +2,13 @@
 import os
 import json
 import logging
+import boto3
 from enum import Enum
 from datetime import datetime
 from dotenv import find_dotenv, load_dotenv
 from langchain_ollama import OllamaEmbeddings
 from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpointEmbeddings
+from langchain_aws import BedrockEmbeddings
 from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
 from starlette.middleware.base import BaseHTTPMiddleware
 from store_factory import get_vector_store
@@ -25,6 +27,7 @@ class EmbeddingsProvider(Enum):
     HUGGINGFACE = "huggingface"
     HUGGINGFACETEI = "huggingfacetei"
     OLLAMA = "ollama"
+    BEDROCK = "bedrock"
 
 
 def get_env_variable(
@@ -168,6 +171,8 @@ RAG_AZURE_OPENAI_ENDPOINT = get_env_variable(
 ).rstrip("/")
 HF_TOKEN = get_env_variable("HF_TOKEN", "")
 OLLAMA_BASE_URL = get_env_variable("OLLAMA_BASE_URL", "http://ollama:11434")
+AWS_ACCESS_KEY_ID = get_env_variable("AWS_ACCESS_KEY_ID", "")
+AWS_SECRET_ACCESS_KEY = get_env_variable("AWS_SECRET_ACCESS_KEY", "")
 
 ## Embeddings
 
@@ -195,6 +200,17 @@ def init_embeddings(provider, model):
         return HuggingFaceEndpointEmbeddings(model=model)
     elif provider == EmbeddingsProvider.OLLAMA:
         return OllamaEmbeddings(model=model, base_url=OLLAMA_BASE_URL)
+    elif provider == EmbeddingsProvider.BEDROCK:
+        session = boto3.Session(
+            aws_access_key_id=AWS_ACCESS_KEY_ID,
+            aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
+            region_name=AWS_DEFAULT_REGION,
+        )
+        return BedrockEmbeddings(
+            client=session.client("bedrock-runtime"),
+            model_id=model,
+            region_name=AWS_DEFAULT_REGION,
+        )
     else:
         raise ValueError(f"Unsupported embeddings provider: {provider}")
 
@@ -217,6 +233,13 @@ elif EMBEDDINGS_PROVIDER == EmbeddingsProvider.HUGGINGFACETEI:
     )
 elif EMBEDDINGS_PROVIDER == EmbeddingsProvider.OLLAMA:
     EMBEDDINGS_MODEL = get_env_variable("EMBEDDINGS_MODEL", "nomic-embed-text")
+elif EMBEDDINGS_PROVIDER == EmbeddingsProvider.BEDROCK:
+    EMBEDDINGS_MODEL = get_env_variable(
+        "EMBEDDINGS_MODEL", "amazon.titan-embed-text-v1"
+    )
+    AWS_DEFAULT_REGION = get_env_variable(
+        "AWS_DEFAULT_REGION", "us-east-1"
+    )
 else:
     raise ValueError(f"Unsupported embeddings provider: {EMBEDDINGS_PROVIDER}")
 
diff --git a/requirements.txt b/requirements.txt
index f015b94..28f18a0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,6 +2,8 @@ langchain==0.3
 langchain_community==0.3
 langchain_openai==0.2.0
 langchain_core==0.3.5
+langchain-aws==0.2.1
+boto3==1.34.144
 sqlalchemy==2.0.28
 python-dotenv==1.0.1
 fastapi==0.110.0
-- 
GitLab