mirror of
https://github.com/TheAlgorithms/Python.git
synced 2025-03-17 20:19:48 +00:00
Add files via upload
This commit is contained in:
parent
f3d43e8694
commit
2dad12b898
134
neural_network/chatbot/chatbot.py
Normal file
134
neural_network/chatbot/chatbot.py
Normal file
@ -0,0 +1,134 @@
|
|||||||
|
import datetime
|
||||||
|
from typing import List, Dict, Any
|
||||||
|
|
||||||
|
|
||||||
|
class Chatbot:
|
||||||
|
"""
|
||||||
|
A Chatbot class to manage chat conversations using an LLM service and a database to store chat data.
|
||||||
|
|
||||||
|
Methods:
|
||||||
|
- start_chat: Starts a new conversation, logs the start time.
|
||||||
|
- handle_user_message: Processes user input and stores user message & bot response in DB.
|
||||||
|
- end_chat: Ends the conversation and logs the end time.
|
||||||
|
- continue_chat: Retains only the last few messages if the conversation exceeds 1000 messages.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, db: Any, llm_service: Any) -> None:
|
||||||
|
"""
|
||||||
|
Initialize the Chatbot with a database and an LLM service.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
- db: The database instance used for storing chat data.
|
||||||
|
- llm_service: The language model service for generating responses.
|
||||||
|
"""
|
||||||
|
self.db = db
|
||||||
|
self.llm_service = llm_service
|
||||||
|
self.conversation_history: List[Dict[str, str]] = []
|
||||||
|
self.chat_id_pk: int = None
|
||||||
|
|
||||||
|
def start_chat(self) -> None:
|
||||||
|
"""
|
||||||
|
Start a new chat session and insert chat history to the database.
|
||||||
|
"""
|
||||||
|
start_time = datetime.datetime.now()
|
||||||
|
is_stream = 1 # Start new conversation
|
||||||
|
self.db.insert_chat_history(start_time, is_stream)
|
||||||
|
self.chat_id_pk = self.db.get_latest_chat_id()
|
||||||
|
|
||||||
|
def handle_user_message(self, user_input: str) -> str:
|
||||||
|
"""
|
||||||
|
Handle user input and generate a bot response.
|
||||||
|
If the user sends '/stop', the conversation is terminated.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
- user_input: The input provided by the user.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- bot_response: The response generated by the bot.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
- ValueError: If user input is not a string or if no chat_id is available.
|
||||||
|
|
||||||
|
Doctest:
|
||||||
|
>>> class MockDatabase:
|
||||||
|
... def __init__(self):
|
||||||
|
... self.data = []
|
||||||
|
... def insert_chat_data(self, *args, **kwargs):
|
||||||
|
... pass
|
||||||
|
... def insert_chat_history(self, *args, **kwargs):
|
||||||
|
... pass
|
||||||
|
... def get_latest_chat_id(self):
|
||||||
|
... return 1
|
||||||
|
...
|
||||||
|
>>> class MockLLM:
|
||||||
|
... def generate_response(self, conversation_history):
|
||||||
|
... if conversation_history[-1]["content"] == "/stop":
|
||||||
|
... return "conversation-terminated"
|
||||||
|
... return "Mock response"
|
||||||
|
>>> db_mock = MockDatabase()
|
||||||
|
>>> llm_mock = MockLLM()
|
||||||
|
>>> bot = Chatbot(db_mock, llm_mock)
|
||||||
|
>>> bot.start_chat()
|
||||||
|
>>> bot.handle_user_message("/stop")
|
||||||
|
'conversation-terminated'
|
||||||
|
>>> bot.handle_user_message("Hello!")
|
||||||
|
'Mock response'
|
||||||
|
"""
|
||||||
|
if not isinstance(user_input, str):
|
||||||
|
raise ValueError("User input must be a string.")
|
||||||
|
|
||||||
|
if self.chat_id_pk is None:
|
||||||
|
raise ValueError("Chat has not been started. Call start_chat() first.")
|
||||||
|
|
||||||
|
self.conversation_history.append({"role": "user", "content": user_input})
|
||||||
|
|
||||||
|
if user_input == "/stop":
|
||||||
|
self.end_chat()
|
||||||
|
return "conversation-terminated"
|
||||||
|
else:
|
||||||
|
bot_response = self.llm_service.generate_response(self.conversation_history)
|
||||||
|
print(f"Bot : ",bot_response)
|
||||||
|
self.conversation_history.append(
|
||||||
|
{"role": "assistant", "content": bot_response}
|
||||||
|
)
|
||||||
|
self._store_message_in_db(user_input, bot_response)
|
||||||
|
|
||||||
|
return bot_response
|
||||||
|
|
||||||
|
def _store_message_in_db(self, user_input: str, bot_response: str) -> None:
|
||||||
|
"""
|
||||||
|
Store user input and bot response in the database.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
- user_input: The message from the user.
|
||||||
|
- bot_response: The response generated by the bot.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
- ValueError: If insertion into the database fails.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
self.db.insert_chat_data(self.chat_id_pk, user_input, bot_response)
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"Failed to insert chat data: {e}")
|
||||||
|
|
||||||
|
def end_chat(self) -> None:
|
||||||
|
"""
|
||||||
|
End the chat session and update the chat history in the database.
|
||||||
|
"""
|
||||||
|
current_time = datetime.datetime.now()
|
||||||
|
is_stream = 2 # End of conversation
|
||||||
|
try:
|
||||||
|
user_input = "/stop"
|
||||||
|
bot_response = "conversation-terminated"
|
||||||
|
print(f"Bot : ",bot_response)
|
||||||
|
self.db.insert_chat_data(self.chat_id_pk, user_input, bot_response)
|
||||||
|
self.db.insert_chat_history(current_time, is_stream)
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"Failed to update chat history: {e}")
|
||||||
|
|
||||||
|
def continue_chat(self) -> None:
|
||||||
|
"""
|
||||||
|
Retain only the last few entries if the conversation exceeds 1000 messages.
|
||||||
|
"""
|
||||||
|
if len(self.conversation_history) > 1000:
|
||||||
|
self.conversation_history = self.conversation_history[-3:]
|
199
neural_network/chatbot/db.py
Normal file
199
neural_network/chatbot/db.py
Normal file
@ -0,0 +1,199 @@
|
|||||||
|
import os
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
import mysql.connector
|
||||||
|
from mysql.connector import MySQLConnection
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
|
class Database:
|
||||||
|
"""
|
||||||
|
A class to manage the connection to the MySQL database using configuration from environment variables.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
-----------
|
||||||
|
config : dict
|
||||||
|
The database connection parameters like user, password, host, and database name.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.config = {
|
||||||
|
"user": os.environ.get("DB_USER"),
|
||||||
|
"password": os.environ.get("DB_PASSWORD"),
|
||||||
|
"host": os.environ.get("DB_HOST"),
|
||||||
|
"database": os.environ.get("DB_NAME"),
|
||||||
|
}
|
||||||
|
|
||||||
|
def connect(self) -> MySQLConnection:
|
||||||
|
"""
|
||||||
|
Establish a connection to the MySQL database.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
--------
|
||||||
|
MySQLConnection
|
||||||
|
A connection object for interacting with the MySQL database.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
-------
|
||||||
|
mysql.connector.Error
|
||||||
|
If the connection to the database fails.
|
||||||
|
"""
|
||||||
|
return mysql.connector.connect(**self.config)
|
||||||
|
|
||||||
|
|
||||||
|
class ChatDatabase:
|
||||||
|
"""
|
||||||
|
A class to manage chat-related database operations, such as creating tables,
|
||||||
|
inserting chat history, and retrieving chat data.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
-----------
|
||||||
|
db : Database
|
||||||
|
An instance of the `Database` class for establishing connections to the MySQL database.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, db: Database) -> None:
|
||||||
|
self.db = db
|
||||||
|
|
||||||
|
def create_tables(self) -> None:
|
||||||
|
"""
|
||||||
|
Create the necessary tables for chat history and chat data in the database.
|
||||||
|
If the tables already exist, they will not be created again.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
-------
|
||||||
|
mysql.connector.Error
|
||||||
|
If there is any error executing the SQL statements.
|
||||||
|
"""
|
||||||
|
conn = self.db.connect()
|
||||||
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
cursor.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS ChatDB.Chat_history (
|
||||||
|
chat_id INT AUTO_INCREMENT PRIMARY KEY,
|
||||||
|
start_time DATETIME,
|
||||||
|
is_stream INT
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
cursor.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS ChatDB.Chat_data (
|
||||||
|
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||||
|
chat_id INT,
|
||||||
|
user TEXT,
|
||||||
|
assistant TEXT,
|
||||||
|
FOREIGN KEY (chat_id) REFERENCES ChatDB.Chat_history(chat_id)
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
cursor.execute("DROP TRIGGER IF EXISTS update_is_stream")
|
||||||
|
|
||||||
|
cursor.execute(
|
||||||
|
"""
|
||||||
|
CREATE TRIGGER update_is_stream
|
||||||
|
AFTER UPDATE ON ChatDB.Chat_history
|
||||||
|
FOR EACH ROW
|
||||||
|
BEGIN
|
||||||
|
UPDATE ChatDB.Chat_data
|
||||||
|
SET is_stream = NEW.is_stream
|
||||||
|
WHERE chat_id = NEW.chat_id;
|
||||||
|
END;
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
conn.commit()
|
||||||
|
cursor.close()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
def insert_chat_history(self, start_time: str, is_stream: int) -> None:
|
||||||
|
"""
|
||||||
|
Insert a new chat history record into the database.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
-----------
|
||||||
|
start_time : str
|
||||||
|
The starting time of the chat session.
|
||||||
|
is_stream : int
|
||||||
|
An integer indicating whether the chat is in progress (1) or ended (2).
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
-------
|
||||||
|
mysql.connector.Error
|
||||||
|
If there is any error executing the SQL statements.
|
||||||
|
"""
|
||||||
|
conn = self.db.connect()
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO ChatDB.Chat_history (start_time, is_stream)
|
||||||
|
VALUES (%s, %s)
|
||||||
|
""",
|
||||||
|
(start_time, is_stream),
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
cursor.close()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
def get_latest_chat_id(self) -> int:
|
||||||
|
"""
|
||||||
|
Retrieve the chat ID of the most recent chat session from the database.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
--------
|
||||||
|
int
|
||||||
|
The ID of the latest chat session.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
-------
|
||||||
|
mysql.connector.Error
|
||||||
|
If there is any error executing the SQL statements.
|
||||||
|
"""
|
||||||
|
conn = self.db.connect()
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute(
|
||||||
|
"""
|
||||||
|
SELECT chat_id FROM ChatDB.Chat_history WHERE
|
||||||
|
chat_id=(SELECT MAX(chat_id) FROM ChatDB.Chat_history)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
chat_id_pk = cursor.fetchone()[0]
|
||||||
|
cursor.close()
|
||||||
|
conn.close()
|
||||||
|
return chat_id_pk
|
||||||
|
|
||||||
|
def insert_chat_data(
|
||||||
|
self, chat_id: int, user_message: str, assistant_message: str
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Insert a new chat data record into the database.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
-----------
|
||||||
|
chat_id : int
|
||||||
|
The ID of the chat session to which this data belongs.
|
||||||
|
user_message : str
|
||||||
|
The message provided by the user in the chat session.
|
||||||
|
assistant_message : str
|
||||||
|
The response from the assistant in the chat session.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
-------
|
||||||
|
mysql.connector.Error
|
||||||
|
If there is any error executing the SQL statements.
|
||||||
|
"""
|
||||||
|
conn = self.db.connect()
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO ChatDB.Chat_data (chat_id, user, assistant)
|
||||||
|
VALUES (%s, %s, %s)
|
||||||
|
""",
|
||||||
|
(chat_id, user_message, assistant_message),
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
cursor.close()
|
||||||
|
conn.close()
|
78
neural_network/chatbot/llm_service.py
Normal file
78
neural_network/chatbot/llm_service.py
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
import os
|
||||||
|
from together import Together
|
||||||
|
from groq import Groq
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from typing import List, Dict
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
|
class LLMService:
|
||||||
|
"""
|
||||||
|
A class to interact with different LLM (Large Language Model) API services, such as Together and Groq.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
-----------
|
||||||
|
api_service : str
|
||||||
|
The name of the API service to use ("Together" or "Groq").
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, api_service: str) -> None:
|
||||||
|
"""
|
||||||
|
Initialize the LLMService with a specific API service.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
-----------
|
||||||
|
api_service : str
|
||||||
|
The name of the LLM API service, either "Together" or "Groq".
|
||||||
|
"""
|
||||||
|
self.api_service = api_service
|
||||||
|
|
||||||
|
def generate_response(self, conversation_history: List[Dict[str, str]]) -> str:
|
||||||
|
"""
|
||||||
|
Generate a response from the specified LLM API based on the conversation history.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
-----------
|
||||||
|
conversation_history : List[Dict[str, str]]
|
||||||
|
The list of conversation messages, where each message is a dictionary with 'role' and 'content' keys.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
--------
|
||||||
|
str
|
||||||
|
The generated response content from the assistant.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
-------
|
||||||
|
ValueError
|
||||||
|
If the specified API service is neither "Together" nor "Groq".
|
||||||
|
"""
|
||||||
|
if self.api_service == "Together":
|
||||||
|
client = Together(api_key=os.environ.get("TOGETHER_API_KEY"))
|
||||||
|
response = client.chat.completions.create(
|
||||||
|
model="meta-llama/Llama-3.2-3B-Instruct-Turbo",
|
||||||
|
messages=conversation_history,
|
||||||
|
max_tokens=512,
|
||||||
|
temperature=0.3,
|
||||||
|
top_p=0.7,
|
||||||
|
top_k=50,
|
||||||
|
repetition_penalty=1,
|
||||||
|
stop=["<|eot_id|>", "<|eom_id|>"],
|
||||||
|
stream=False,
|
||||||
|
)
|
||||||
|
elif self.api_service == "Groq":
|
||||||
|
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
|
||||||
|
response = client.chat.completions.create(
|
||||||
|
model="llama3-8b-8192",
|
||||||
|
messages=conversation_history,
|
||||||
|
max_tokens=1024,
|
||||||
|
temperature=0.3,
|
||||||
|
top_p=0.7,
|
||||||
|
stop=["<|eot_id|>", "<|eom_id|>"],
|
||||||
|
stream=False,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported API service: {self.api_service}")
|
||||||
|
|
||||||
|
# Extracting the content of the generated response
|
||||||
|
return response.choices[0].message.content
|
44
neural_network/chatbot/main.py
Normal file
44
neural_network/chatbot/main.py
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
from db import Database, ChatDatabase
|
||||||
|
from llm_service import LLMService
|
||||||
|
from chatbot import Chatbot
|
||||||
|
from typing import NoReturn
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> NoReturn:
|
||||||
|
"""
|
||||||
|
Main function to initialize and start the chatbot application.
|
||||||
|
|
||||||
|
This function initializes the database and LLM service, creates necessary tables, and starts
|
||||||
|
the chatbot for user interaction.
|
||||||
|
"""
|
||||||
|
# Initialize and configure the database
|
||||||
|
db = Database()
|
||||||
|
chat_db = ChatDatabase(db)
|
||||||
|
chat_db.create_tables()
|
||||||
|
|
||||||
|
# Set the API service to either "Together" or "Groq"
|
||||||
|
api_service = (
|
||||||
|
"Groq" # Can be set dynamically based on user preference or environment
|
||||||
|
)
|
||||||
|
llm_service = LLMService(api_service)
|
||||||
|
|
||||||
|
# Initialize the Chatbot with the database and LLM service
|
||||||
|
chatbot = Chatbot(chat_db, llm_service)
|
||||||
|
|
||||||
|
print("Welcome to the chatbot! Type '/stop' to end the conversation.")
|
||||||
|
chatbot.start_chat()
|
||||||
|
|
||||||
|
# Chat loop to handle user input
|
||||||
|
while True:
|
||||||
|
user_input = input("\nYou: ")
|
||||||
|
if user_input.strip().lower() == "/stop":
|
||||||
|
chatbot.end_chat() # End the conversation if user types "/stop"
|
||||||
|
break
|
||||||
|
chatbot.handle_user_message(
|
||||||
|
user_input
|
||||||
|
) # Process user input and generate response
|
||||||
|
chatbot.continue_chat() # Handle long conversations (trim history if necessary)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
57
neural_network/chatbot/requirements.txt
Normal file
57
neural_network/chatbot/requirements.txt
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
aiohappyeyeballs==2.4.2
|
||||||
|
aiohttp==3.10.8
|
||||||
|
aiosignal==1.3.1
|
||||||
|
annotated-types==0.7.0
|
||||||
|
anyio==4.6.0
|
||||||
|
asgiref==3.8.1
|
||||||
|
attrs==24.2.0
|
||||||
|
black==24.10.0
|
||||||
|
certifi==2024.8.30
|
||||||
|
cfgv==3.4.0
|
||||||
|
charset-normalizer==3.3.2
|
||||||
|
click==8.1.7
|
||||||
|
distlib==0.3.9
|
||||||
|
distro==1.9.0
|
||||||
|
Django==5.1.1
|
||||||
|
djangorestframework==3.15.2
|
||||||
|
eval_type_backport==0.2.0
|
||||||
|
filelock==3.16.1
|
||||||
|
frozenlist==1.4.1
|
||||||
|
groq==0.11.0
|
||||||
|
h11==0.14.0
|
||||||
|
httpcore==1.0.5
|
||||||
|
httpx==0.27.2
|
||||||
|
identify==2.6.1
|
||||||
|
idna==3.10
|
||||||
|
markdown-it-py==3.0.0
|
||||||
|
mdurl==0.1.2
|
||||||
|
multidict==6.1.0
|
||||||
|
mypy-extensions==1.0.0
|
||||||
|
mysql-connector-python==9.0.0
|
||||||
|
nodeenv==1.9.1
|
||||||
|
numpy==2.1.1
|
||||||
|
packaging==24.1
|
||||||
|
pathspec==0.12.1
|
||||||
|
pillow==10.4.0
|
||||||
|
platformdirs==4.3.6
|
||||||
|
pre_commit==4.0.1
|
||||||
|
pyarrow==17.0.0
|
||||||
|
pydantic==2.9.2
|
||||||
|
pydantic_core==2.23.4
|
||||||
|
Pygments==2.18.0
|
||||||
|
python-dotenv==1.0.1
|
||||||
|
PyYAML==6.0.2
|
||||||
|
requests==2.32.3
|
||||||
|
rich==13.8.1
|
||||||
|
ruff==0.7.0
|
||||||
|
shellingham==1.5.4
|
||||||
|
sniffio==1.3.1
|
||||||
|
sqlparse==0.5.1
|
||||||
|
tabulate==0.9.0
|
||||||
|
together==1.3.0
|
||||||
|
tqdm==4.66.5
|
||||||
|
typer==0.12.5
|
||||||
|
typing_extensions==4.12.2
|
||||||
|
urllib3==2.2.3
|
||||||
|
virtualenv==20.27.0
|
||||||
|
yarl==1.13.1
|
Loading…
x
Reference in New Issue
Block a user