diff --git a/neural_network/chatbot/chat_db.py b/neural_network/chatbot/chat_db.py index e888bc0ec..c1758ed0a 100644 --- a/neural_network/chatbot/chat_db.py +++ b/neural_network/chatbot/chat_db.py @@ -22,6 +22,7 @@ db_config = { "database": os.environ.get("DB_NAME"), } + class LLMService: def __init__(self, api_service: str): self.api_service = api_service @@ -62,7 +63,7 @@ class LLMService: stop=["<|eot_id|>", "<|eom_id|>"], stream=False, ) - + return response.choices[0].message.content @@ -176,7 +177,7 @@ class ChatDB: cursor = conn.cursor() cursor.execute( """ - SELECT chat_id FROM ChatDB.Chat_history + SELECT chat_id FROM ChatDB.Chat_history ORDER BY chat_id DESC LIMIT 1 """ ) @@ -190,10 +191,12 @@ class ChatDB: conn.close() @staticmethod - def insert_chat_data(chat_id: int, user_message: str, assistant_message: str) -> None: + def insert_chat_data( + chat_id: int, user_message: str, assistant_message: str + ) -> None: """ Insert a new row into the ChatDB.Chat_data table. - + Example: >>> ChatDB.insert_chat_data(1, 'Hello', 'Hi there!') Chat data inserted successfully. @@ -228,7 +231,7 @@ class Chatbot: """ Start a chatbot session, allowing the user to interact with the LLM. Saves conversation history in the database and ends the session on "/stop" command. - + Example: >>> chatbot = Chatbot(api_service="Groq") >>> chatbot.chat_session() # This will be mocked in the tests @@ -243,11 +246,17 @@ class Chatbot: if self.chat_id_pk is None: if user_input.lower() == "/stop": break - bot_response = self.llm_service.generate_response(self.conversation_history) - self.conversation_history.append({"role": "assistant", "content": bot_response}) + bot_response = self.llm_service.generate_response( + self.conversation_history + ) + self.conversation_history.append( + {"role": "assistant", "content": bot_response} + ) is_stream = 1 # New conversation - self.chat_id_pk = ChatDB.insert_chat_history(self.start_time, is_stream) # Return the chat_id + self.chat_id_pk = ChatDB.insert_chat_history( + self.start_time, is_stream + ) # Return the chat_id if self.chat_id_pk: ChatDB.insert_chat_data(self.chat_id_pk, user_input, bot_response) else: @@ -257,8 +266,12 @@ class Chatbot: ChatDB.insert_chat_history(current_time, is_stream) break - bot_response = self.llm_service.generate_response(self.conversation_history) - self.conversation_history.append({"role": "assistant", "content": bot_response}) + bot_response = self.llm_service.generate_response( + self.conversation_history + ) + self.conversation_history.append( + {"role": "assistant", "content": bot_response} + ) is_stream = 0 # Continuation of conversation current_time = datetime.datetime.now(datetime.timezone.utc) @@ -268,23 +281,30 @@ class Chatbot: if len(self.conversation_history) > 1000: self.conversation_history = self.conversation_history[-3:] + # Test cases for Chatbot class TestChatbot(unittest.TestCase): - - @patch('builtins.input', side_effect=["Hello", "/stop"]) - @patch('sys.stdout', new_callable=StringIO) + @patch("builtins.input", side_effect=["Hello", "/stop"]) + @patch("sys.stdout", new_callable=StringIO) def test_chat_session(self, mock_stdout, mock_input): """ Test the chat_session method for expected welcome message. """ chatbot = Chatbot(api_service="Groq") chatbot.chat_session() - + # Check for the welcome message in the output output = mock_stdout.getvalue().strip().splitlines() - self.assertIn("Welcome to the chatbot! Type '/stop' to end the conversation.", output) - self.assertTrue(any("Chat history inserted successfully." in line for line in output)) - self.assertTrue(any("Chat data inserted successfully." in line for line in output)) + self.assertIn( + "Welcome to the chatbot! Type '/stop' to end the conversation.", output + ) + self.assertTrue( + any("Chat history inserted successfully." in line for line in output) + ) + self.assertTrue( + any("Chat data inserted successfully." in line for line in output) + ) + if __name__ == "__main__": #