mirror of
https://github.com/TheAlgorithms/Python.git
synced 2025-03-17 20:19:48 +00:00
Add files via upload
This commit is contained in:
parent
789e975148
commit
3e4430dfb6
@ -1,14 +1,12 @@
|
|||||||
"""
|
|
||||||
credits : https://medium.com/google-developer-experts/beyond-live-sessions-building-persistent-memory-chatbots-with-langchain-gemini-pro-and-firebase-19d6f84e21d3
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import datetime
|
import datetime
|
||||||
from dotenv import load_dotenv
|
|
||||||
import mysql.connector
|
import mysql.connector
|
||||||
|
from dotenv import load_dotenv
|
||||||
from together import Together
|
from together import Together
|
||||||
from groq import Groq
|
from groq import Groq
|
||||||
|
import unittest
|
||||||
|
from unittest.mock import patch
|
||||||
|
from io import StringIO
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
@ -20,227 +18,271 @@ db_config = {
|
|||||||
"database": os.environ.get("DB_NAME"),
|
"database": os.environ.get("DB_NAME"),
|
||||||
}
|
}
|
||||||
|
|
||||||
api_service = os.environ.get("API_SERVICE")
|
class LLMService:
|
||||||
|
def __init__(self, api_service: str):
|
||||||
|
self.api_service = api_service
|
||||||
def create_tables() -> None:
|
if self.api_service == "Together":
|
||||||
"""
|
self.client = Together(api_key=os.environ.get("TOGETHER_API_KEY"))
|
||||||
Create the ChatDB.Chat_history and ChatDB.Chat_data tables
|
|
||||||
if they do not exist.Also, create a trigger to update is_stream
|
|
||||||
in Chat_data when Chat_history.is_stream is updated.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
conn = mysql.connector.connect(**db_config)
|
|
||||||
cursor = conn.cursor()
|
|
||||||
|
|
||||||
cursor.execute(
|
|
||||||
"""
|
|
||||||
CREATE TABLE IF NOT EXISTS ChatDB.Chat_history (
|
|
||||||
chat_id INT AUTO_INCREMENT PRIMARY KEY,
|
|
||||||
start_time DATETIME,
|
|
||||||
is_stream INT
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
cursor.execute(
|
|
||||||
"""
|
|
||||||
CREATE TABLE IF NOT EXISTS ChatDB.Chat_data (
|
|
||||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
|
||||||
chat_id INT,
|
|
||||||
user TEXT,
|
|
||||||
assistant TEXT,
|
|
||||||
FOREIGN KEY (chat_id) REFERENCES ChatDB.Chat_history(chat_id)
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
cursor.execute("DROP TRIGGER IF EXISTS update_is_stream;")
|
|
||||||
|
|
||||||
cursor.execute(
|
|
||||||
"""
|
|
||||||
CREATE TRIGGER update_is_stream
|
|
||||||
AFTER UPDATE ON ChatDB.Chat_history
|
|
||||||
FOR EACH ROW
|
|
||||||
BEGIN
|
|
||||||
UPDATE ChatDB.Chat_data
|
|
||||||
SET is_stream = NEW.is_stream
|
|
||||||
WHERE chat_id = NEW.chat_id;
|
|
||||||
END;
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
conn.commit()
|
|
||||||
except mysql.connector.Error as err:
|
|
||||||
print(f"Error: {err}")
|
|
||||||
finally:
|
|
||||||
cursor.close()
|
|
||||||
conn.close()
|
|
||||||
print("Tables and trigger created successfully")
|
|
||||||
|
|
||||||
|
|
||||||
def insert_chat_history(start_time: datetime.datetime, is_stream: int) -> None:
|
|
||||||
"""
|
|
||||||
Insert a new row into the ChatDB.Chat_history table.
|
|
||||||
:param start_time: Timestamp of when the chat started
|
|
||||||
:param is_stream: Indicator of whether the conversation is
|
|
||||||
ongoing, starting, or ending
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
conn = mysql.connector.connect(**db_config)
|
|
||||||
cursor = conn.cursor()
|
|
||||||
cursor.execute(
|
|
||||||
"""
|
|
||||||
INSERT INTO ChatDB.Chat_history (start_time, is_stream)
|
|
||||||
VALUES (%s, %s)
|
|
||||||
""",
|
|
||||||
(start_time, is_stream),
|
|
||||||
)
|
|
||||||
conn.commit()
|
|
||||||
except mysql.connector.Error as err:
|
|
||||||
print(f"Error: {err}")
|
|
||||||
finally:
|
|
||||||
cursor.close()
|
|
||||||
conn.close()
|
|
||||||
|
|
||||||
|
|
||||||
def get_latest_chat_id() -> int:
|
|
||||||
"""
|
|
||||||
Retrieve the latest chat_id from the ChatDB.Chat_history table.
|
|
||||||
:return: The latest chat_id or None if no chat_id exists.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
conn = mysql.connector.connect(**db_config)
|
|
||||||
cursor = conn.cursor()
|
|
||||||
cursor.execute(
|
|
||||||
"""
|
|
||||||
SELECT chat_id FROM ChatDB.Chat_history
|
|
||||||
ORDER BY chat_id DESC LIMIT 1
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
chat_id = cursor.fetchone()[0]
|
|
||||||
return chat_id if chat_id else None
|
|
||||||
except mysql.connector.Error as err:
|
|
||||||
print(f"Error: {err}")
|
|
||||||
return 0
|
|
||||||
finally:
|
|
||||||
cursor.close()
|
|
||||||
conn.close()
|
|
||||||
|
|
||||||
|
|
||||||
def insert_chat_data(chat_id: int, user_message: str, assistant_message: str) -> None:
|
|
||||||
"""
|
|
||||||
Insert a new row into the ChatDB.Chat_data table.
|
|
||||||
:param chat_id: The ID of the chat session
|
|
||||||
:param user_message: The user's message
|
|
||||||
:param assistant_message: The assistant's message
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
conn = mysql.connector.connect(**db_config)
|
|
||||||
cursor = conn.cursor()
|
|
||||||
cursor.execute(
|
|
||||||
"""
|
|
||||||
INSERT INTO ChatDB.Chat_data (chat_id, user, assistant)
|
|
||||||
VALUES (%s, %s, %s)
|
|
||||||
""",
|
|
||||||
(chat_id, user_message, assistant_message),
|
|
||||||
)
|
|
||||||
conn.commit()
|
|
||||||
except mysql.connector.Error as err:
|
|
||||||
print(f"Error: {err}")
|
|
||||||
finally:
|
|
||||||
cursor.close()
|
|
||||||
conn.close()
|
|
||||||
|
|
||||||
|
|
||||||
def generate_llm_response(
|
|
||||||
conversation_history: list[dict], api_service: str = "Groq"
|
|
||||||
) -> str:
|
|
||||||
"""
|
|
||||||
Generate a response from the LLM based on the conversation history.
|
|
||||||
:param conversation_history: List of dictionaries representing
|
|
||||||
the conversation so far
|
|
||||||
:param api_service: Choose between "Together" or "Groq" as the
|
|
||||||
API service
|
|
||||||
:return: Assistant's response as a string
|
|
||||||
"""
|
|
||||||
bot_response = ""
|
|
||||||
if api_service == "Together":
|
|
||||||
client = Together(api_key=os.environ.get("TOGETHER_API_KEY"))
|
|
||||||
response = client.chat.completions.create(
|
|
||||||
model="meta-llama/Llama-3.2-3B-Instruct-Turbo",
|
|
||||||
messages=conversation_history,
|
|
||||||
max_tokens=512,
|
|
||||||
temperature=0.3,
|
|
||||||
top_p=0.7,
|
|
||||||
top_k=50,
|
|
||||||
repetition_penalty=1,
|
|
||||||
stop=["<|eot_id|>", "<|eom_id|>"],
|
|
||||||
stream=False,
|
|
||||||
)
|
|
||||||
bot_response = response.choices[0].message.content
|
|
||||||
else:
|
|
||||||
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
|
|
||||||
response = client.chat.completions.create(
|
|
||||||
model="llama3-8b-8192",
|
|
||||||
messages=conversation_history,
|
|
||||||
max_tokens=1024,
|
|
||||||
temperature=0.3,
|
|
||||||
top_p=0.7,
|
|
||||||
stop=["<|eot_id|>", "<|eom_id|>"],
|
|
||||||
stream=False,
|
|
||||||
)
|
|
||||||
bot_response = response.choices[0].message.content
|
|
||||||
|
|
||||||
return bot_response
|
|
||||||
|
|
||||||
|
|
||||||
def chat_session() -> None:
|
|
||||||
"""
|
|
||||||
Start a chatbot session, allowing the user to interact with the LLM.
|
|
||||||
Saves conversation history in the database and ends the session on "/stop" command.
|
|
||||||
"""
|
|
||||||
print("Welcome to the chatbot! Type '/stop' to end the conversation.")
|
|
||||||
|
|
||||||
conversation_history = []
|
|
||||||
start_time = datetime.datetime.now(datetime.timezone.utc)
|
|
||||||
|
|
||||||
chat_id_pk = None
|
|
||||||
api_service = "Groq" # or "Together"
|
|
||||||
|
|
||||||
while True:
|
|
||||||
user_input = input("\nYou: ").strip()
|
|
||||||
conversation_history.append({"role": "user", "content": user_input})
|
|
||||||
|
|
||||||
if chat_id_pk is None:
|
|
||||||
if user_input.lower() == "/stop":
|
|
||||||
break
|
|
||||||
bot_response = generate_llm_response(conversation_history, api_service)
|
|
||||||
conversation_history.append({"role": "assistant", "content": bot_response})
|
|
||||||
|
|
||||||
is_stream = 1 # New conversation
|
|
||||||
insert_chat_history(start_time, is_stream)
|
|
||||||
chat_id_pk = get_latest_chat_id()
|
|
||||||
insert_chat_data(chat_id_pk, user_input, bot_response)
|
|
||||||
else:
|
else:
|
||||||
if user_input.lower() == "/stop":
|
self.client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
|
||||||
is_stream = 2 # End of conversation
|
|
||||||
|
def generate_response(self, conversation_history: list[dict]) -> str:
|
||||||
|
"""
|
||||||
|
Generate a response from the LLM based on the conversation history.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> llm_service = LLMService(api_service="Groq")
|
||||||
|
>>> response = llm_service.generate_response([{"role": "user", "content": "Hello"}])
|
||||||
|
>>> isinstance(response, str)
|
||||||
|
True
|
||||||
|
"""
|
||||||
|
if self.api_service == "Together":
|
||||||
|
response = self.client.chat.completions.create(
|
||||||
|
model="meta-llama/Llama-3.2-3B-Instruct-Turbo",
|
||||||
|
messages=conversation_history,
|
||||||
|
max_tokens=512,
|
||||||
|
temperature=0.3,
|
||||||
|
top_p=0.7,
|
||||||
|
top_k=50,
|
||||||
|
repetition_penalty=1,
|
||||||
|
stop=["<|eot_id|>", "<|eom_id|>"],
|
||||||
|
stream=False,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
response = self.client.chat.completions.create(
|
||||||
|
model="llama3-8b-8192",
|
||||||
|
messages=conversation_history,
|
||||||
|
max_tokens=1024,
|
||||||
|
temperature=0.3,
|
||||||
|
top_p=0.7,
|
||||||
|
stop=["<|eot_id|>", "<|eom_id|>"],
|
||||||
|
stream=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
return response.choices[0].message.content
|
||||||
|
|
||||||
|
|
||||||
|
class ChatDB:
|
||||||
|
@staticmethod
|
||||||
|
def create_tables() -> None:
|
||||||
|
"""
|
||||||
|
Create the ChatDB.Chat_history and ChatDB.Chat_data tables
|
||||||
|
if they do not exist. Also, create a trigger to update is_stream
|
||||||
|
in Chat_data when Chat_history.is_stream is updated.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> ChatDB.create_tables()
|
||||||
|
Tables and trigger created successfully
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
conn = mysql.connector.connect(**db_config)
|
||||||
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
cursor.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS ChatDB.Chat_history (
|
||||||
|
chat_id INT AUTO_INCREMENT PRIMARY KEY,
|
||||||
|
start_time DATETIME,
|
||||||
|
is_stream INT
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
cursor.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS ChatDB.Chat_data (
|
||||||
|
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||||
|
chat_id INT,
|
||||||
|
user TEXT,
|
||||||
|
assistant TEXT,
|
||||||
|
FOREIGN KEY (chat_id) REFERENCES ChatDB.Chat_history(chat_id)
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
cursor.execute("DROP TRIGGER IF EXISTS update_is_stream;")
|
||||||
|
|
||||||
|
cursor.execute(
|
||||||
|
"""
|
||||||
|
CREATE TRIGGER update_is_stream
|
||||||
|
AFTER UPDATE ON ChatDB.Chat_history
|
||||||
|
FOR EACH ROW
|
||||||
|
BEGIN
|
||||||
|
UPDATE ChatDB.Chat_data
|
||||||
|
SET is_stream = NEW.is_stream
|
||||||
|
WHERE chat_id = NEW.chat_id;
|
||||||
|
END;
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
conn.commit()
|
||||||
|
print("Tables and trigger created successfully")
|
||||||
|
except mysql.connector.Error as err:
|
||||||
|
print(f"Error: {err}")
|
||||||
|
finally:
|
||||||
|
cursor.close()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def insert_chat_history(start_time: datetime.datetime, is_stream: int) -> int:
|
||||||
|
"""
|
||||||
|
Insert a new row into the ChatDB.Chat_history table and return the inserted chat_id.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> from datetime import datetime
|
||||||
|
>>> chat_id = ChatDB.insert_chat_history(datetime(2024, 1, 1, 12, 0, 0), 1)
|
||||||
|
>>> isinstance(chat_id, int)
|
||||||
|
True
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
conn = mysql.connector.connect(**db_config)
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO ChatDB.Chat_history (start_time, is_stream)
|
||||||
|
VALUES (%s, %s)
|
||||||
|
""",
|
||||||
|
(start_time, is_stream),
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
cursor.execute("SELECT LAST_INSERT_ID()")
|
||||||
|
chat_id = cursor.fetchone()[0]
|
||||||
|
print("Chat history inserted successfully.")
|
||||||
|
return chat_id
|
||||||
|
except mysql.connector.Error as err:
|
||||||
|
print(f"Error: {err}")
|
||||||
|
return None
|
||||||
|
finally:
|
||||||
|
cursor.close()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_latest_chat_id() -> int:
|
||||||
|
"""
|
||||||
|
Retrieve the latest chat_id from the ChatDB.Chat_history table.
|
||||||
|
:return: The latest chat_id or None if no chat_id exists.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> chat_id = ChatDB.get_latest_chat_id()
|
||||||
|
>>> isinstance(chat_id, int)
|
||||||
|
True
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
conn = mysql.connector.connect(**db_config)
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute(
|
||||||
|
"""
|
||||||
|
SELECT chat_id FROM ChatDB.Chat_history
|
||||||
|
ORDER BY chat_id DESC LIMIT 1
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
chat_id = cursor.fetchone()[0]
|
||||||
|
return chat_id if chat_id else None
|
||||||
|
except mysql.connector.Error as err:
|
||||||
|
print(f"Error: {err}")
|
||||||
|
return None
|
||||||
|
finally:
|
||||||
|
cursor.close()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def insert_chat_data(chat_id: int, user_message: str, assistant_message: str) -> None:
|
||||||
|
"""
|
||||||
|
Insert a new row into the ChatDB.Chat_data table.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> ChatDB.insert_chat_data(1, 'Hello', 'Hi there!')
|
||||||
|
Chat data inserted successfully.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
conn = mysql.connector.connect(**db_config)
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO ChatDB.Chat_data (chat_id, user, assistant)
|
||||||
|
VALUES (%s, %s, %s)
|
||||||
|
""",
|
||||||
|
(chat_id, user_message, assistant_message),
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
print("Chat data inserted successfully.")
|
||||||
|
except mysql.connector.Error as err:
|
||||||
|
print(f"Error: {err}")
|
||||||
|
finally:
|
||||||
|
cursor.close()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
class Chatbot:
|
||||||
|
def __init__(self, api_service: str):
|
||||||
|
self.llm_service = LLMService(api_service)
|
||||||
|
self.conversation_history = []
|
||||||
|
self.chat_id_pk = None
|
||||||
|
self.start_time = datetime.datetime.now(datetime.timezone.utc)
|
||||||
|
|
||||||
|
def chat_session(self) -> None:
|
||||||
|
"""
|
||||||
|
Start a chatbot session, allowing the user to interact with the LLM.
|
||||||
|
Saves conversation history in the database and ends the session on "/stop" command.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> chatbot = Chatbot(api_service="Groq")
|
||||||
|
>>> chatbot.chat_session() # This will be mocked in the tests
|
||||||
|
Welcome to the chatbot! Type '/stop' to end the conversation.
|
||||||
|
"""
|
||||||
|
print("Welcome to the chatbot! Type '/stop' to end the conversation.")
|
||||||
|
|
||||||
|
while True:
|
||||||
|
user_input = input("\nYou: ").strip()
|
||||||
|
self.conversation_history.append({"role": "user", "content": user_input})
|
||||||
|
|
||||||
|
if self.chat_id_pk is None:
|
||||||
|
if user_input.lower() == "/stop":
|
||||||
|
break
|
||||||
|
bot_response = self.llm_service.generate_response(self.conversation_history)
|
||||||
|
self.conversation_history.append({"role": "assistant", "content": bot_response})
|
||||||
|
|
||||||
|
is_stream = 1 # New conversation
|
||||||
|
self.chat_id_pk = ChatDB.insert_chat_history(self.start_time, is_stream) # Return the chat_id
|
||||||
|
if self.chat_id_pk:
|
||||||
|
ChatDB.insert_chat_data(self.chat_id_pk, user_input, bot_response)
|
||||||
|
else:
|
||||||
|
if user_input.lower() == "/stop":
|
||||||
|
is_stream = 2 # End of conversation
|
||||||
|
current_time = datetime.datetime.now(datetime.timezone.utc)
|
||||||
|
ChatDB.insert_chat_history(current_time, is_stream)
|
||||||
|
break
|
||||||
|
|
||||||
|
bot_response = self.llm_service.generate_response(self.conversation_history)
|
||||||
|
self.conversation_history.append({"role": "assistant", "content": bot_response})
|
||||||
|
|
||||||
|
is_stream = 0 # Continuation of conversation
|
||||||
current_time = datetime.datetime.now(datetime.timezone.utc)
|
current_time = datetime.datetime.now(datetime.timezone.utc)
|
||||||
insert_chat_history(current_time, is_stream)
|
ChatDB.insert_chat_history(current_time, is_stream)
|
||||||
break
|
ChatDB.insert_chat_data(self.chat_id_pk, user_input, bot_response)
|
||||||
|
|
||||||
bot_response = generate_llm_response(conversation_history, api_service)
|
if len(self.conversation_history) > 1000:
|
||||||
conversation_history.append({"role": "assistant", "content": bot_response})
|
self.conversation_history = self.conversation_history[-3:]
|
||||||
|
|
||||||
is_stream = 0 # Continuation of conversation
|
# Test cases for Chatbot
|
||||||
current_time = datetime.datetime.now(datetime.timezone.utc)
|
class TestChatbot(unittest.TestCase):
|
||||||
insert_chat_history(current_time, is_stream)
|
|
||||||
insert_chat_data(chat_id_pk, user_input, bot_response)
|
|
||||||
|
|
||||||
if len(conversation_history) > 1000:
|
@patch('builtins.input', side_effect=["Hello", "/stop"])
|
||||||
conversation_history = conversation_history[-3:]
|
@patch('sys.stdout', new_callable=StringIO)
|
||||||
|
def test_chat_session(self, mock_stdout, mock_input):
|
||||||
|
"""
|
||||||
|
Test the chat_session method for expected welcome message.
|
||||||
|
"""
|
||||||
|
chatbot = Chatbot(api_service="Groq")
|
||||||
|
chatbot.chat_session()
|
||||||
|
|
||||||
|
# Check for the welcome message in the output
|
||||||
|
output = mock_stdout.getvalue().strip().splitlines()
|
||||||
|
self.assertIn("Welcome to the chatbot! Type '/stop' to end the conversation.", output)
|
||||||
|
self.assertTrue(any("Chat history inserted successfully." in line for line in output))
|
||||||
|
self.assertTrue(any("Chat data inserted successfully." in line for line in output))
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
# starting a chat session
|
#
|
||||||
create_tables()
|
ChatDB.create_tables()
|
||||||
chat_session()
|
unittest.main()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user