diff --git a/config.yaml b/config.yaml new file mode 100644 index 0000000..ad073f7 --- /dev/null +++ b/config.yaml @@ -0,0 +1,14 @@ +database: + name: hn_embeddings + +model: + name: text-embedding-ada-002 + +paths: + pickle_cache: data/embeddings_cache.pkl + +openai: + api_key: YOUR_OPENAI_API_KEY + +rag_pipeline: + top_k: 10 diff --git a/src/connection.py b/src/connection.py index 6c910aa..0b0775a 100644 --- a/src/connection.py +++ b/src/connection.py @@ -11,11 +11,14 @@ def open_connection(dbname=None) -> duckdb.DuckDBPyConnection: Returns: `duckdb.DuckDBPyConnection`: A connection object to the local database. """ - if dbname: - return duckdb.connect(f"{dbname}.db") - else: - return duckdb.connect(":memory:") - + try: + if dbname: + return duckdb.connect(f"{dbname}.db") + else: + return duckdb.connect(":memory:") + except Exception as e: + print(f"Error connecting to the database: {e}") + raise def load_extension( con: duckdb.DuckDBPyConnection, extension: str @@ -36,7 +39,7 @@ def load_extension( try: con.install_extension(extension) con.load_extension(extension) - except: - print(f"Could not load extension {extension}") - pass + except Exception as e: + print(f"Could not load extension {extension}: {e}") + raise return con diff --git a/src/embedding.py b/src/embedding.py index df717e5..ec988cd 100644 --- a/src/embedding.py +++ b/src/embedding.py @@ -2,7 +2,7 @@ import src.operations as operations import src.openai_client as openai_client from .connection import DuckDBPyConnection - +from .operations import EmbeddingKey # Function to get embeddings, using the cache @@ -24,7 +24,7 @@ def pickle_embeddings( pickle_cache = operations.load_pickle_cache(pickle_path) for text in texts: - key = (text, model) + key = EmbeddingKey(text, model) if key not in pickle_cache: pickle_cache[key] = openai_client.create_embedding(text, model=model) embeddings.append(pickle_cache[key]) @@ -49,12 +49,13 @@ def duckdb_embeddings( """ embeddings = [] for text in texts: + key = EmbeddingKey(text, model) # check to see if embedding is in duckdb table - result = operations.is_key_in_table(con, (text, model)) + result = operations.is_key_in_table(con, key) if result: print("Embedding found in table") # if so, get it - embedding = operations.get_embedding_from_table(con, text, model) + embedding = operations.get_embedding_from_table(con, key) embeddings.append(embedding) else: print("Embedding not found in table") @@ -62,7 +63,7 @@ def duckdb_embeddings( # if not, create it embedding = openai_client.create_embedding(text, model) # and write it to the table - operations.write_embedding_to_table(con, text, model, embedding) + operations.write_embedding_to_table(con, key, embedding) embeddings.append(embedding) return embeddings diff --git a/src/embedding_operations.py b/src/embedding_operations.py new file mode 100644 index 0000000..5fae96c --- /dev/null +++ b/src/embedding_operations.py @@ -0,0 +1,66 @@ +from typing import List, Tuple +from .connection import DuckDBPyConnection +from .operations import EmbeddingKey, load_pickle_cache, save_pickle_cache, is_key_in_table, get_embedding_from_table, write_embedding_to_table +import src.openai_client as openai_client + +class EmbeddingOperations: + def __init__(self, con: DuckDBPyConnection, pickle_path: str): + self.con = con + self.pickle_path = pickle_path + + def pickle_embeddings(self, texts: List[str], model: str) -> List[List[float]]: + embeddings = [] + pickle_cache = load_pickle_cache(self.pickle_path) + + for text in texts: + key = EmbeddingKey(text, model) + if key not in pickle_cache: + pickle_cache[key] = openai_client.create_embedding(text, model=model) + embeddings.append(pickle_cache[key]) + save_pickle_cache(pickle_cache, self.pickle_path) + return embeddings + + def duckdb_embeddings(self, texts: List[str], model: str) -> List[List[float]]: + embeddings = [] + for text in texts: + key = EmbeddingKey(text, model) + result = is_key_in_table(self.con, key) + if result: + embedding = get_embedding_from_table(self.con, key) + embeddings.append(embedding) + else: + embedding = openai_client.create_embedding(text, model) + write_embedding_to_table(self.con, key, embedding) + embeddings.append(embedding) + return embeddings + + def cosine_similarity(self, l1, l2) -> float: + return self.con.execute(f"SELECT list_cosine_similarity({l1}, {l2})").fetchall()[0][0] + + def get_similarity(self, text: str, model: str) -> list[tuple[str, float]]: + sql = """ + WITH q1 AS ( + SELECT + ? as text, + ?::DOUBLE[] AS embedding + ), + + q2 AS ( + select + distinct text, + embedding::DOUBLE[] as embedding + from embeddings + ) + + SELECT + b.text, + list_cosine_similarity(a.embedding::DOUBLE[], b.embedding::DOUBLE[]) AS similarity + FROM q1 a + join q2 b on a.text != b.text + ORDER BY similarity DESC + LIMIT 10 + """ + + embedding = self.duckdb_embeddings([text], model)[0] + result = self.con.execute(sql, [text, embedding]).fetchall() + return result diff --git a/src/openai_client.py b/src/openai_client.py index de38899..0ccd77a 100644 --- a/src/openai_client.py +++ b/src/openai_client.py @@ -13,6 +13,8 @@ def get_openai_client() -> OpenAI: :return: An instance of the OpenAI client. """ key = os.getenv("OPENAI_API_KEY") + if not key: + raise ValueError("OpenAI API key not found in environment variables.") client = OpenAI(api_key=key) return client @@ -34,9 +36,13 @@ def create_embedding( try: client = get_openai_client() except Exception as e: - print(e) + print(f"Error initializing OpenAI client: {e}") return [] text = text.replace("\n", " ") - response = client.embeddings.create(input=[text], model=model, **kwargs) - return response.data[0].embedding + try: + response = client.embeddings.create(input=[text], model=model, **kwargs) + return response.data[0].embedding + except Exception as e: + print(f"Error creating embedding: {e}") + return [] diff --git a/src/operations.py b/src/operations.py index d7ef73c..22b305d 100644 --- a/src/operations.py +++ b/src/operations.py @@ -9,23 +9,44 @@ PickleCache = Dict[Tuple[str, str], List[float]] +class EmbeddingKey: + def __init__(self, text: str, model: str): + self._text = text + self._model = model + + @property + def text(self): + return self._text + + @property + def model(self): + return self._model + + def __eq__(self, other): + if isinstance(other, EmbeddingKey): + return self.text == other.text and self.model == other.model + return False + + def __hash__(self): + return hash((self.text, self.model)) + + def write_embedding_to_table( - con: DuckDBPyConnection, text: str, model: str, embedding: List[float] + con: DuckDBPyConnection, key: EmbeddingKey, embedding: List[float] ) -> DuckDBPyConnection: """ Writes the given embedding to the `embeddings` table in the database. Args: con (DuckDBPyConnection): The connection to the DuckDB database. - text (str): The text associated with the embedding. - model (str): The model used to generate the embedding. + key (EmbeddingKey): The key associated with the embedding. embedding (List[float]): The embedding vector. Returns: DuckDBPyConnection: The connection to the DuckDB database after the insertion. """ create_table_if_not_exists(con) - con.execute("INSERT INTO embeddings VALUES (?, ?, ?)", [text, model, embedding]) + con.execute("INSERT INTO embeddings VALUES (?, ?, ?)", [key.text, key.model, embedding]) return con @@ -44,13 +65,13 @@ def create_table_if_not_exists(con) -> None: ) -def is_key_in_table(con: DuckDBPyConnection, key: Tuple[str, str]) -> bool: +def is_key_in_table(con: DuckDBPyConnection, key: EmbeddingKey) -> bool: """ Check if a key exists in the embeddings table. Args: con (DuckDBPyConnection): The connection to the DuckDB database. - key (Tuple[str, str]): The key to check in the format (text, model). + key (EmbeddingKey): The key to check. Returns: bool: True if the key exists in the table, False otherwise. @@ -58,7 +79,7 @@ def is_key_in_table(con: DuckDBPyConnection, key: Tuple[str, str]) -> bool: create_table_if_not_exists(con) result = con.execute( "SELECT EXISTS(SELECT * FROM embeddings WHERE text=? AND model=?)", - [key[0], key[1]], + [key.text, key.model], ).fetchone() if result: return result[0] @@ -66,17 +87,17 @@ def is_key_in_table(con: DuckDBPyConnection, key: Tuple[str, str]) -> bool: def list_keys_in_table( - con: DuckDBPyConnection, keys: List[Tuple[str, str]] -) -> list[tuple[str, str]]: + con: DuckDBPyConnection, keys: List[EmbeddingKey] +) -> list[EmbeddingKey]: """ Returns a list of keys that exist in the specified table. Args: con (DuckDBPyConnection): The connection to the DuckDB database. - keys (List[Tuple[str, str]]): The keys to check in the table. + keys (List[EmbeddingKey]): The keys to check in the table. Returns: - List[Tuple[str, str]]: A list of keys that exist in the table. + List[EmbeddingKey]: A list of keys that exist in the table. """ keys_in_table = [] @@ -117,7 +138,8 @@ def write_pickle_cache_to_duckdb(con: DuckDBPyConnection, pickle_path: str) -> N cache = load_pickle_cache(pickle_path) create_table_if_not_exists(con) for key, value in cache.items(): - write_embedding_to_table(con, key[0], key[1], value) + embedding_key = EmbeddingKey(key[0], key[1]) + write_embedding_to_table(con, embedding_key, value) # Function to save the cache to a file @@ -136,24 +158,23 @@ def save_pickle_cache(cache: PickleCache, cache_path: str) -> None: pickle.dump(cache, file) -def get_embedding_from_table(con: DuckDBPyConnection, text: str, model: str) -> List[float]: +def get_embedding_from_table(con: DuckDBPyConnection, key: EmbeddingKey) -> List[float]: """ - Retrieves the embedding from the 'embeddings' table based on the given text and model. + Retrieves the embedding from the 'embeddings' table based on the given key. Args: con (DuckDBPyConnection): The connection to the DuckDB database. - text (str): The text to search for in the 'text' column of the table. - model (str): The model to search for in the 'model' column of the table. + key (EmbeddingKey): The key to search for in the table. Returns: - List[float]: The embedding associated with the given text and model. + List[float]: The embedding associated with the given key. Raises: - ValueError: If the embedding for the given text and model is not found in the table. + ValueError: If the embedding for the given key is not found in the table. """ result = con.execute( - "SELECT embedding FROM embeddings WHERE text=? AND model=?", [text, model] + "SELECT embedding FROM embeddings WHERE text=? AND model=?", [key.text, key.model] ).fetchone() if result: return result[0] - raise ValueError(f"Embedding for {text} with model {model} not found in table") + raise ValueError(f"Embedding for {key.text} with model {key.model} not found in table") diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 0000000..4e556f5 --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,27 @@ +import os +import pytest +import yaml +from dotenv import load_dotenv + +load_dotenv() + +@pytest.fixture +def config_data(): + with open("config.yaml", "r") as file: + return yaml.safe_load(file) + +def test_database_name(config_data): + assert config_data["database"]["name"] == "hn_embeddings" + +def test_model_name(config_data): + assert config_data["model"]["name"] == "text-embedding-ada-002" + +def test_pickle_cache_path(config_data): + assert config_data["paths"]["pickle_cache"] == "data/embeddings_cache.pkl" + +def test_openai_api_key(): + api_key = os.getenv("OPENAI_API_KEY") + assert api_key == "YOUR_OPENAI_API_KEY" + +def test_rag_pipeline_top_k(config_data): + assert config_data["rag_pipeline"]["top_k"] == 10 diff --git a/tests/test_embedding_operations.py b/tests/test_embedding_operations.py new file mode 100644 index 0000000..c4b66d3 --- /dev/null +++ b/tests/test_embedding_operations.py @@ -0,0 +1,44 @@ +import pytest +from src.embedding_operations import EmbeddingOperations +from src.connection import open_connection +from src.operations import EmbeddingKey + +@pytest.fixture +def setup_db(): + con = open_connection(":memory:") + con.execute("CREATE TABLE embeddings (text VARCHAR, model VARCHAR, embedding DOUBLE[])") + yield con + con.close() + +def test_pickle_embeddings(setup_db): + con = setup_db + embedding_ops = EmbeddingOperations(con, "test_cache.pkl") + texts = ["test text 1", "test text 2"] + model = "test-model" + embeddings = embedding_ops.pickle_embeddings(texts, model) + assert len(embeddings) == 2 + +def test_duckdb_embeddings(setup_db): + con = setup_db + embedding_ops = EmbeddingOperations(con, "test_cache.pkl") + texts = ["test text 1", "test text 2"] + model = "test-model" + embeddings = embedding_ops.duckdb_embeddings(texts, model) + assert len(embeddings) == 2 + +def test_cosine_similarity(setup_db): + con = setup_db + embedding_ops = EmbeddingOperations(con, "test_cache.pkl") + l1 = [1.0, 2.0, 3.0] + l2 = [1.0, 2.0, 3.0] + similarity = embedding_ops.cosine_similarity(l1, l2) + assert similarity == 1.0 + +def test_get_similarity(setup_db): + con = setup_db + embedding_ops = EmbeddingOperations(con, "test_cache.pkl") + text = "test text" + model = "test-model" + con.execute("INSERT INTO embeddings VALUES (?, ?, ?)", [text, model, [1.0, 2.0, 3.0]]) + result = embedding_ops.get_similarity(text, model) + assert len(result) == 0