From 3e4430dfb66dd2a1339fa9aa32ec50b97ee5293e Mon Sep 17 00:00:00 2001 From: Pritam Das <69068731+Pritam3355@users.noreply.github.com> Date: Sun, 20 Oct 2024 13:14:28 +0530 Subject: [PATCH] Add files via upload --- neural_network/chatbot/chat_db.py | 490 ++++++++++++++++-------------- 1 file changed, 266 insertions(+), 224 deletions(-) diff --git a/neural_network/chatbot/chat_db.py b/neural_network/chatbot/chat_db.py index 5f2751294..3dce67120 100644 --- a/neural_network/chatbot/chat_db.py +++ b/neural_network/chatbot/chat_db.py @@ -1,14 +1,12 @@ -""" -credits : https://medium.com/google-developer-experts/beyond-live-sessions-building-persistent-memory-chatbots-with-langchain-gemini-pro-and-firebase-19d6f84e21d3 - -""" - import os import datetime -from dotenv import load_dotenv import mysql.connector +from dotenv import load_dotenv from together import Together from groq import Groq +import unittest +from unittest.mock import patch +from io import StringIO load_dotenv() @@ -20,227 +18,271 @@ db_config = { "database": os.environ.get("DB_NAME"), } -api_service = os.environ.get("API_SERVICE") - - -def create_tables() -> None: - """ - Create the ChatDB.Chat_history and ChatDB.Chat_data tables - if they do not exist.Also, create a trigger to update is_stream - in Chat_data when Chat_history.is_stream is updated. - """ - try: - conn = mysql.connector.connect(**db_config) - cursor = conn.cursor() - - cursor.execute( - """ - CREATE TABLE IF NOT EXISTS ChatDB.Chat_history ( - chat_id INT AUTO_INCREMENT PRIMARY KEY, - start_time DATETIME, - is_stream INT - ) - """ - ) - - cursor.execute( - """ - CREATE TABLE IF NOT EXISTS ChatDB.Chat_data ( - id INT AUTO_INCREMENT PRIMARY KEY, - chat_id INT, - user TEXT, - assistant TEXT, - FOREIGN KEY (chat_id) REFERENCES ChatDB.Chat_history(chat_id) - ) - """ - ) - - cursor.execute("DROP TRIGGER IF EXISTS update_is_stream;") - - cursor.execute( - """ - CREATE TRIGGER update_is_stream - AFTER UPDATE ON ChatDB.Chat_history - FOR EACH ROW - BEGIN - UPDATE ChatDB.Chat_data - SET is_stream = NEW.is_stream - WHERE chat_id = NEW.chat_id; - END; - """ - ) - - conn.commit() - except mysql.connector.Error as err: - print(f"Error: {err}") - finally: - cursor.close() - conn.close() - print("Tables and trigger created successfully") - - -def insert_chat_history(start_time: datetime.datetime, is_stream: int) -> None: - """ - Insert a new row into the ChatDB.Chat_history table. - :param start_time: Timestamp of when the chat started - :param is_stream: Indicator of whether the conversation is - ongoing, starting, or ending - """ - try: - conn = mysql.connector.connect(**db_config) - cursor = conn.cursor() - cursor.execute( - """ - INSERT INTO ChatDB.Chat_history (start_time, is_stream) - VALUES (%s, %s) - """, - (start_time, is_stream), - ) - conn.commit() - except mysql.connector.Error as err: - print(f"Error: {err}") - finally: - cursor.close() - conn.close() - - -def get_latest_chat_id() -> int: - """ - Retrieve the latest chat_id from the ChatDB.Chat_history table. - :return: The latest chat_id or None if no chat_id exists. - """ - try: - conn = mysql.connector.connect(**db_config) - cursor = conn.cursor() - cursor.execute( - """ - SELECT chat_id FROM ChatDB.Chat_history - ORDER BY chat_id DESC LIMIT 1 - """ - ) - chat_id = cursor.fetchone()[0] - return chat_id if chat_id else None - except mysql.connector.Error as err: - print(f"Error: {err}") - return 0 - finally: - cursor.close() - conn.close() - - -def insert_chat_data(chat_id: int, user_message: str, assistant_message: str) -> None: - """ - Insert a new row into the ChatDB.Chat_data table. - :param chat_id: The ID of the chat session - :param user_message: The user's message - :param assistant_message: The assistant's message - """ - try: - conn = mysql.connector.connect(**db_config) - cursor = conn.cursor() - cursor.execute( - """ - INSERT INTO ChatDB.Chat_data (chat_id, user, assistant) - VALUES (%s, %s, %s) - """, - (chat_id, user_message, assistant_message), - ) - conn.commit() - except mysql.connector.Error as err: - print(f"Error: {err}") - finally: - cursor.close() - conn.close() - - -def generate_llm_response( - conversation_history: list[dict], api_service: str = "Groq" -) -> str: - """ - Generate a response from the LLM based on the conversation history. - :param conversation_history: List of dictionaries representing - the conversation so far - :param api_service: Choose between "Together" or "Groq" as the - API service - :return: Assistant's response as a string - """ - bot_response = "" - if api_service == "Together": - client = Together(api_key=os.environ.get("TOGETHER_API_KEY")) - response = client.chat.completions.create( - model="meta-llama/Llama-3.2-3B-Instruct-Turbo", - messages=conversation_history, - max_tokens=512, - temperature=0.3, - top_p=0.7, - top_k=50, - repetition_penalty=1, - stop=["<|eot_id|>", "<|eom_id|>"], - stream=False, - ) - bot_response = response.choices[0].message.content - else: - client = Groq(api_key=os.environ.get("GROQ_API_KEY")) - response = client.chat.completions.create( - model="llama3-8b-8192", - messages=conversation_history, - max_tokens=1024, - temperature=0.3, - top_p=0.7, - stop=["<|eot_id|>", "<|eom_id|>"], - stream=False, - ) - bot_response = response.choices[0].message.content - - return bot_response - - -def chat_session() -> None: - """ - Start a chatbot session, allowing the user to interact with the LLM. - Saves conversation history in the database and ends the session on "/stop" command. - """ - print("Welcome to the chatbot! Type '/stop' to end the conversation.") - - conversation_history = [] - start_time = datetime.datetime.now(datetime.timezone.utc) - - chat_id_pk = None - api_service = "Groq" # or "Together" - - while True: - user_input = input("\nYou: ").strip() - conversation_history.append({"role": "user", "content": user_input}) - - if chat_id_pk is None: - if user_input.lower() == "/stop": - break - bot_response = generate_llm_response(conversation_history, api_service) - conversation_history.append({"role": "assistant", "content": bot_response}) - - is_stream = 1 # New conversation - insert_chat_history(start_time, is_stream) - chat_id_pk = get_latest_chat_id() - insert_chat_data(chat_id_pk, user_input, bot_response) +class LLMService: + def __init__(self, api_service: str): + self.api_service = api_service + if self.api_service == "Together": + self.client = Together(api_key=os.environ.get("TOGETHER_API_KEY")) else: - if user_input.lower() == "/stop": - is_stream = 2 # End of conversation + self.client = Groq(api_key=os.environ.get("GROQ_API_KEY")) + + def generate_response(self, conversation_history: list[dict]) -> str: + """ + Generate a response from the LLM based on the conversation history. + + Example: + >>> llm_service = LLMService(api_service="Groq") + >>> response = llm_service.generate_response([{"role": "user", "content": "Hello"}]) + >>> isinstance(response, str) + True + """ + if self.api_service == "Together": + response = self.client.chat.completions.create( + model="meta-llama/Llama-3.2-3B-Instruct-Turbo", + messages=conversation_history, + max_tokens=512, + temperature=0.3, + top_p=0.7, + top_k=50, + repetition_penalty=1, + stop=["<|eot_id|>", "<|eom_id|>"], + stream=False, + ) + else: + response = self.client.chat.completions.create( + model="llama3-8b-8192", + messages=conversation_history, + max_tokens=1024, + temperature=0.3, + top_p=0.7, + stop=["<|eot_id|>", "<|eom_id|>"], + stream=False, + ) + + return response.choices[0].message.content + + +class ChatDB: + @staticmethod + def create_tables() -> None: + """ + Create the ChatDB.Chat_history and ChatDB.Chat_data tables + if they do not exist. Also, create a trigger to update is_stream + in Chat_data when Chat_history.is_stream is updated. + + Example: + >>> ChatDB.create_tables() + Tables and trigger created successfully + """ + try: + conn = mysql.connector.connect(**db_config) + cursor = conn.cursor() + + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS ChatDB.Chat_history ( + chat_id INT AUTO_INCREMENT PRIMARY KEY, + start_time DATETIME, + is_stream INT + ) + """ + ) + + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS ChatDB.Chat_data ( + id INT AUTO_INCREMENT PRIMARY KEY, + chat_id INT, + user TEXT, + assistant TEXT, + FOREIGN KEY (chat_id) REFERENCES ChatDB.Chat_history(chat_id) + ) + """ + ) + + cursor.execute("DROP TRIGGER IF EXISTS update_is_stream;") + + cursor.execute( + """ + CREATE TRIGGER update_is_stream + AFTER UPDATE ON ChatDB.Chat_history + FOR EACH ROW + BEGIN + UPDATE ChatDB.Chat_data + SET is_stream = NEW.is_stream + WHERE chat_id = NEW.chat_id; + END; + """ + ) + + conn.commit() + print("Tables and trigger created successfully") + except mysql.connector.Error as err: + print(f"Error: {err}") + finally: + cursor.close() + conn.close() + + @staticmethod + def insert_chat_history(start_time: datetime.datetime, is_stream: int) -> int: + """ + Insert a new row into the ChatDB.Chat_history table and return the inserted chat_id. + + Example: + >>> from datetime import datetime + >>> chat_id = ChatDB.insert_chat_history(datetime(2024, 1, 1, 12, 0, 0), 1) + >>> isinstance(chat_id, int) + True + """ + try: + conn = mysql.connector.connect(**db_config) + cursor = conn.cursor() + cursor.execute( + """ + INSERT INTO ChatDB.Chat_history (start_time, is_stream) + VALUES (%s, %s) + """, + (start_time, is_stream), + ) + conn.commit() + cursor.execute("SELECT LAST_INSERT_ID()") + chat_id = cursor.fetchone()[0] + print("Chat history inserted successfully.") + return chat_id + except mysql.connector.Error as err: + print(f"Error: {err}") + return None + finally: + cursor.close() + conn.close() + + @staticmethod + def get_latest_chat_id() -> int: + """ + Retrieve the latest chat_id from the ChatDB.Chat_history table. + :return: The latest chat_id or None if no chat_id exists. + + Example: + >>> chat_id = ChatDB.get_latest_chat_id() + >>> isinstance(chat_id, int) + True + """ + try: + conn = mysql.connector.connect(**db_config) + cursor = conn.cursor() + cursor.execute( + """ + SELECT chat_id FROM ChatDB.Chat_history + ORDER BY chat_id DESC LIMIT 1 + """ + ) + chat_id = cursor.fetchone()[0] + return chat_id if chat_id else None + except mysql.connector.Error as err: + print(f"Error: {err}") + return None + finally: + cursor.close() + conn.close() + + @staticmethod + def insert_chat_data(chat_id: int, user_message: str, assistant_message: str) -> None: + """ + Insert a new row into the ChatDB.Chat_data table. + + Example: + >>> ChatDB.insert_chat_data(1, 'Hello', 'Hi there!') + Chat data inserted successfully. + """ + try: + conn = mysql.connector.connect(**db_config) + cursor = conn.cursor() + cursor.execute( + """ + INSERT INTO ChatDB.Chat_data (chat_id, user, assistant) + VALUES (%s, %s, %s) + """, + (chat_id, user_message, assistant_message), + ) + conn.commit() + print("Chat data inserted successfully.") + except mysql.connector.Error as err: + print(f"Error: {err}") + finally: + cursor.close() + conn.close() + + +class Chatbot: + def __init__(self, api_service: str): + self.llm_service = LLMService(api_service) + self.conversation_history = [] + self.chat_id_pk = None + self.start_time = datetime.datetime.now(datetime.timezone.utc) + + def chat_session(self) -> None: + """ + Start a chatbot session, allowing the user to interact with the LLM. + Saves conversation history in the database and ends the session on "/stop" command. + + Example: + >>> chatbot = Chatbot(api_service="Groq") + >>> chatbot.chat_session() # This will be mocked in the tests + Welcome to the chatbot! Type '/stop' to end the conversation. + """ + print("Welcome to the chatbot! Type '/stop' to end the conversation.") + + while True: + user_input = input("\nYou: ").strip() + self.conversation_history.append({"role": "user", "content": user_input}) + + if self.chat_id_pk is None: + if user_input.lower() == "/stop": + break + bot_response = self.llm_service.generate_response(self.conversation_history) + self.conversation_history.append({"role": "assistant", "content": bot_response}) + + is_stream = 1 # New conversation + self.chat_id_pk = ChatDB.insert_chat_history(self.start_time, is_stream) # Return the chat_id + if self.chat_id_pk: + ChatDB.insert_chat_data(self.chat_id_pk, user_input, bot_response) + else: + if user_input.lower() == "/stop": + is_stream = 2 # End of conversation + current_time = datetime.datetime.now(datetime.timezone.utc) + ChatDB.insert_chat_history(current_time, is_stream) + break + + bot_response = self.llm_service.generate_response(self.conversation_history) + self.conversation_history.append({"role": "assistant", "content": bot_response}) + + is_stream = 0 # Continuation of conversation current_time = datetime.datetime.now(datetime.timezone.utc) - insert_chat_history(current_time, is_stream) - break + ChatDB.insert_chat_history(current_time, is_stream) + ChatDB.insert_chat_data(self.chat_id_pk, user_input, bot_response) - bot_response = generate_llm_response(conversation_history, api_service) - conversation_history.append({"role": "assistant", "content": bot_response}) + if len(self.conversation_history) > 1000: + self.conversation_history = self.conversation_history[-3:] - is_stream = 0 # Continuation of conversation - current_time = datetime.datetime.now(datetime.timezone.utc) - insert_chat_history(current_time, is_stream) - insert_chat_data(chat_id_pk, user_input, bot_response) +# Test cases for Chatbot +class TestChatbot(unittest.TestCase): - if len(conversation_history) > 1000: - conversation_history = conversation_history[-3:] + @patch('builtins.input', side_effect=["Hello", "/stop"]) + @patch('sys.stdout', new_callable=StringIO) + def test_chat_session(self, mock_stdout, mock_input): + """ + Test the chat_session method for expected welcome message. + """ + chatbot = Chatbot(api_service="Groq") + chatbot.chat_session() + + # Check for the welcome message in the output + output = mock_stdout.getvalue().strip().splitlines() + self.assertIn("Welcome to the chatbot! Type '/stop' to end the conversation.", output) + self.assertTrue(any("Chat history inserted successfully." in line for line in output)) + self.assertTrue(any("Chat data inserted successfully." in line for line in output)) - -# starting a chat session -create_tables() -chat_session() +if __name__ == "__main__": + # + ChatDB.create_tables() + unittest.main()