From 2dad12b898385d16e4e1a075d2613922ac62e0d5 Mon Sep 17 00:00:00 2001 From: Pritam Das <69068731+Pritam3355@users.noreply.github.com> Date: Sat, 19 Oct 2024 10:12:46 +0530 Subject: [PATCH] Add files via upload --- neural_network/chatbot/chatbot.py | 134 ++++++++++++++++ neural_network/chatbot/db.py | 199 ++++++++++++++++++++++++ neural_network/chatbot/llm_service.py | 78 ++++++++++ neural_network/chatbot/main.py | 44 ++++++ neural_network/chatbot/requirements.txt | 57 +++++++ 5 files changed, 512 insertions(+) create mode 100644 neural_network/chatbot/chatbot.py create mode 100644 neural_network/chatbot/db.py create mode 100644 neural_network/chatbot/llm_service.py create mode 100644 neural_network/chatbot/main.py create mode 100644 neural_network/chatbot/requirements.txt diff --git a/neural_network/chatbot/chatbot.py b/neural_network/chatbot/chatbot.py new file mode 100644 index 000000000..38488349f --- /dev/null +++ b/neural_network/chatbot/chatbot.py @@ -0,0 +1,134 @@ +import datetime +from typing import List, Dict, Any + + +class Chatbot: + """ + A Chatbot class to manage chat conversations using an LLM service and a database to store chat data. + + Methods: + - start_chat: Starts a new conversation, logs the start time. + - handle_user_message: Processes user input and stores user message & bot response in DB. + - end_chat: Ends the conversation and logs the end time. + - continue_chat: Retains only the last few messages if the conversation exceeds 1000 messages. + """ + + def __init__(self, db: Any, llm_service: Any) -> None: + """ + Initialize the Chatbot with a database and an LLM service. + + Parameters: + - db: The database instance used for storing chat data. + - llm_service: The language model service for generating responses. + """ + self.db = db + self.llm_service = llm_service + self.conversation_history: List[Dict[str, str]] = [] + self.chat_id_pk: int = None + + def start_chat(self) -> None: + """ + Start a new chat session and insert chat history to the database. + """ + start_time = datetime.datetime.now() + is_stream = 1 # Start new conversation + self.db.insert_chat_history(start_time, is_stream) + self.chat_id_pk = self.db.get_latest_chat_id() + + def handle_user_message(self, user_input: str) -> str: + """ + Handle user input and generate a bot response. + If the user sends '/stop', the conversation is terminated. + + Parameters: + - user_input: The input provided by the user. + + Returns: + - bot_response: The response generated by the bot. + + Raises: + - ValueError: If user input is not a string or if no chat_id is available. + + Doctest: + >>> class MockDatabase: + ... def __init__(self): + ... self.data = [] + ... def insert_chat_data(self, *args, **kwargs): + ... pass + ... def insert_chat_history(self, *args, **kwargs): + ... pass + ... def get_latest_chat_id(self): + ... return 1 + ... + >>> class MockLLM: + ... def generate_response(self, conversation_history): + ... if conversation_history[-1]["content"] == "/stop": + ... return "conversation-terminated" + ... return "Mock response" + >>> db_mock = MockDatabase() + >>> llm_mock = MockLLM() + >>> bot = Chatbot(db_mock, llm_mock) + >>> bot.start_chat() + >>> bot.handle_user_message("/stop") + 'conversation-terminated' + >>> bot.handle_user_message("Hello!") + 'Mock response' + """ + if not isinstance(user_input, str): + raise ValueError("User input must be a string.") + + if self.chat_id_pk is None: + raise ValueError("Chat has not been started. Call start_chat() first.") + + self.conversation_history.append({"role": "user", "content": user_input}) + + if user_input == "/stop": + self.end_chat() + return "conversation-terminated" + else: + bot_response = self.llm_service.generate_response(self.conversation_history) + print(f"Bot : ",bot_response) + self.conversation_history.append( + {"role": "assistant", "content": bot_response} + ) + self._store_message_in_db(user_input, bot_response) + + return bot_response + + def _store_message_in_db(self, user_input: str, bot_response: str) -> None: + """ + Store user input and bot response in the database. + + Parameters: + - user_input: The message from the user. + - bot_response: The response generated by the bot. + + Raises: + - ValueError: If insertion into the database fails. + """ + try: + self.db.insert_chat_data(self.chat_id_pk, user_input, bot_response) + except Exception as e: + raise ValueError(f"Failed to insert chat data: {e}") + + def end_chat(self) -> None: + """ + End the chat session and update the chat history in the database. + """ + current_time = datetime.datetime.now() + is_stream = 2 # End of conversation + try: + user_input = "/stop" + bot_response = "conversation-terminated" + print(f"Bot : ",bot_response) + self.db.insert_chat_data(self.chat_id_pk, user_input, bot_response) + self.db.insert_chat_history(current_time, is_stream) + except Exception as e: + raise ValueError(f"Failed to update chat history: {e}") + + def continue_chat(self) -> None: + """ + Retain only the last few entries if the conversation exceeds 1000 messages. + """ + if len(self.conversation_history) > 1000: + self.conversation_history = self.conversation_history[-3:] diff --git a/neural_network/chatbot/db.py b/neural_network/chatbot/db.py new file mode 100644 index 000000000..92ef6909c --- /dev/null +++ b/neural_network/chatbot/db.py @@ -0,0 +1,199 @@ +import os +from dotenv import load_dotenv +import mysql.connector +from mysql.connector import MySQLConnection + +load_dotenv() + + +class Database: + """ + A class to manage the connection to the MySQL database using configuration from environment variables. + + Attributes: + ----------- + config : dict + The database connection parameters like user, password, host, and database name. + """ + + def __init__(self) -> None: + self.config = { + "user": os.environ.get("DB_USER"), + "password": os.environ.get("DB_PASSWORD"), + "host": os.environ.get("DB_HOST"), + "database": os.environ.get("DB_NAME"), + } + + def connect(self) -> MySQLConnection: + """ + Establish a connection to the MySQL database. + + Returns: + -------- + MySQLConnection + A connection object for interacting with the MySQL database. + + Raises: + ------- + mysql.connector.Error + If the connection to the database fails. + """ + return mysql.connector.connect(**self.config) + + +class ChatDatabase: + """ + A class to manage chat-related database operations, such as creating tables, + inserting chat history, and retrieving chat data. + + Attributes: + ----------- + db : Database + An instance of the `Database` class for establishing connections to the MySQL database. + """ + + def __init__(self, db: Database) -> None: + self.db = db + + def create_tables(self) -> None: + """ + Create the necessary tables for chat history and chat data in the database. + If the tables already exist, they will not be created again. + + Raises: + ------- + mysql.connector.Error + If there is any error executing the SQL statements. + """ + conn = self.db.connect() + cursor = conn.cursor() + + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS ChatDB.Chat_history ( + chat_id INT AUTO_INCREMENT PRIMARY KEY, + start_time DATETIME, + is_stream INT + ) + """ + ) + + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS ChatDB.Chat_data ( + id INT AUTO_INCREMENT PRIMARY KEY, + chat_id INT, + user TEXT, + assistant TEXT, + FOREIGN KEY (chat_id) REFERENCES ChatDB.Chat_history(chat_id) + ) + """ + ) + + cursor.execute("DROP TRIGGER IF EXISTS update_is_stream") + + cursor.execute( + """ + CREATE TRIGGER update_is_stream + AFTER UPDATE ON ChatDB.Chat_history + FOR EACH ROW + BEGIN + UPDATE ChatDB.Chat_data + SET is_stream = NEW.is_stream + WHERE chat_id = NEW.chat_id; + END; + """ + ) + + conn.commit() + cursor.close() + conn.close() + + def insert_chat_history(self, start_time: str, is_stream: int) -> None: + """ + Insert a new chat history record into the database. + + Parameters: + ----------- + start_time : str + The starting time of the chat session. + is_stream : int + An integer indicating whether the chat is in progress (1) or ended (2). + + Raises: + ------- + mysql.connector.Error + If there is any error executing the SQL statements. + """ + conn = self.db.connect() + cursor = conn.cursor() + cursor.execute( + """ + INSERT INTO ChatDB.Chat_history (start_time, is_stream) + VALUES (%s, %s) + """, + (start_time, is_stream), + ) + conn.commit() + cursor.close() + conn.close() + + def get_latest_chat_id(self) -> int: + """ + Retrieve the chat ID of the most recent chat session from the database. + + Returns: + -------- + int + The ID of the latest chat session. + + Raises: + ------- + mysql.connector.Error + If there is any error executing the SQL statements. + """ + conn = self.db.connect() + cursor = conn.cursor() + cursor.execute( + """ + SELECT chat_id FROM ChatDB.Chat_history WHERE + chat_id=(SELECT MAX(chat_id) FROM ChatDB.Chat_history) + """ + ) + chat_id_pk = cursor.fetchone()[0] + cursor.close() + conn.close() + return chat_id_pk + + def insert_chat_data( + self, chat_id: int, user_message: str, assistant_message: str + ) -> None: + """ + Insert a new chat data record into the database. + + Parameters: + ----------- + chat_id : int + The ID of the chat session to which this data belongs. + user_message : str + The message provided by the user in the chat session. + assistant_message : str + The response from the assistant in the chat session. + + Raises: + ------- + mysql.connector.Error + If there is any error executing the SQL statements. + """ + conn = self.db.connect() + cursor = conn.cursor() + cursor.execute( + """ + INSERT INTO ChatDB.Chat_data (chat_id, user, assistant) + VALUES (%s, %s, %s) + """, + (chat_id, user_message, assistant_message), + ) + conn.commit() + cursor.close() + conn.close() diff --git a/neural_network/chatbot/llm_service.py b/neural_network/chatbot/llm_service.py new file mode 100644 index 000000000..f1203f642 --- /dev/null +++ b/neural_network/chatbot/llm_service.py @@ -0,0 +1,78 @@ +import os +from together import Together +from groq import Groq +from dotenv import load_dotenv +from typing import List, Dict + +load_dotenv() + + +class LLMService: + """ + A class to interact with different LLM (Large Language Model) API services, such as Together and Groq. + + Attributes: + ----------- + api_service : str + The name of the API service to use ("Together" or "Groq"). + """ + + def __init__(self, api_service: str) -> None: + """ + Initialize the LLMService with a specific API service. + + Parameters: + ----------- + api_service : str + The name of the LLM API service, either "Together" or "Groq". + """ + self.api_service = api_service + + def generate_response(self, conversation_history: List[Dict[str, str]]) -> str: + """ + Generate a response from the specified LLM API based on the conversation history. + + Parameters: + ----------- + conversation_history : List[Dict[str, str]] + The list of conversation messages, where each message is a dictionary with 'role' and 'content' keys. + + Returns: + -------- + str + The generated response content from the assistant. + + Raises: + ------- + ValueError + If the specified API service is neither "Together" nor "Groq". + """ + if self.api_service == "Together": + client = Together(api_key=os.environ.get("TOGETHER_API_KEY")) + response = client.chat.completions.create( + model="meta-llama/Llama-3.2-3B-Instruct-Turbo", + messages=conversation_history, + max_tokens=512, + temperature=0.3, + top_p=0.7, + top_k=50, + repetition_penalty=1, + stop=["<|eot_id|>", "<|eom_id|>"], + stream=False, + ) + elif self.api_service == "Groq": + client = Groq(api_key=os.environ.get("GROQ_API_KEY")) + response = client.chat.completions.create( + model="llama3-8b-8192", + messages=conversation_history, + max_tokens=1024, + temperature=0.3, + top_p=0.7, + stop=["<|eot_id|>", "<|eom_id|>"], + stream=False, + ) + else: + raise ValueError(f"Unsupported API service: {self.api_service}") + + # Extracting the content of the generated response + return response.choices[0].message.content diff --git a/neural_network/chatbot/main.py b/neural_network/chatbot/main.py new file mode 100644 index 000000000..cdbd631c7 --- /dev/null +++ b/neural_network/chatbot/main.py @@ -0,0 +1,44 @@ +from db import Database, ChatDatabase +from llm_service import LLMService +from chatbot import Chatbot +from typing import NoReturn + + +def main() -> NoReturn: + """ + Main function to initialize and start the chatbot application. + + This function initializes the database and LLM service, creates necessary tables, and starts + the chatbot for user interaction. + """ + # Initialize and configure the database + db = Database() + chat_db = ChatDatabase(db) + chat_db.create_tables() + + # Set the API service to either "Together" or "Groq" + api_service = ( + "Groq" # Can be set dynamically based on user preference or environment + ) + llm_service = LLMService(api_service) + + # Initialize the Chatbot with the database and LLM service + chatbot = Chatbot(chat_db, llm_service) + + print("Welcome to the chatbot! Type '/stop' to end the conversation.") + chatbot.start_chat() + + # Chat loop to handle user input + while True: + user_input = input("\nYou: ") + if user_input.strip().lower() == "/stop": + chatbot.end_chat() # End the conversation if user types "/stop" + break + chatbot.handle_user_message( + user_input + ) # Process user input and generate response + chatbot.continue_chat() # Handle long conversations (trim history if necessary) + + +if __name__ == "__main__": + main() diff --git a/neural_network/chatbot/requirements.txt b/neural_network/chatbot/requirements.txt new file mode 100644 index 000000000..0f1204243 --- /dev/null +++ b/neural_network/chatbot/requirements.txt @@ -0,0 +1,57 @@ +aiohappyeyeballs==2.4.2 +aiohttp==3.10.8 +aiosignal==1.3.1 +annotated-types==0.7.0 +anyio==4.6.0 +asgiref==3.8.1 +attrs==24.2.0 +black==24.10.0 +certifi==2024.8.30 +cfgv==3.4.0 +charset-normalizer==3.3.2 +click==8.1.7 +distlib==0.3.9 +distro==1.9.0 +Django==5.1.1 +djangorestframework==3.15.2 +eval_type_backport==0.2.0 +filelock==3.16.1 +frozenlist==1.4.1 +groq==0.11.0 +h11==0.14.0 +httpcore==1.0.5 +httpx==0.27.2 +identify==2.6.1 +idna==3.10 +markdown-it-py==3.0.0 +mdurl==0.1.2 +multidict==6.1.0 +mypy-extensions==1.0.0 +mysql-connector-python==9.0.0 +nodeenv==1.9.1 +numpy==2.1.1 +packaging==24.1 +pathspec==0.12.1 +pillow==10.4.0 +platformdirs==4.3.6 +pre_commit==4.0.1 +pyarrow==17.0.0 +pydantic==2.9.2 +pydantic_core==2.23.4 +Pygments==2.18.0 +python-dotenv==1.0.1 +PyYAML==6.0.2 +requests==2.32.3 +rich==13.8.1 +ruff==0.7.0 +shellingham==1.5.4 +sniffio==1.3.1 +sqlparse==0.5.1 +tabulate==0.9.0 +together==1.3.0 +tqdm==4.66.5 +typer==0.12.5 +typing_extensions==4.12.2 +urllib3==2.2.3 +virtualenv==20.27.0 +yarl==1.13.1