Seinfeld Markov Chain Chatbot

Started by ron77, Apr 03, 2025, 04:14 PM

ron77

#!/usr/bin/env python3
"""
Seinfeld Markov Chain Chatbot
A Python implementation of a Markov chain chatbot trained on Seinfeld dialogue.
"""

import random
import re
import pickle
import os
from collections import defaultdict, Counter
import string

class MarkovModel:
    """Markov chain model for generating text based on trained data."""
    
    def __init__(self):
        self.tokens = []  # List of all tokens
        self.token_to_idx = {}  # Mapping from token to its index
        self.transitions = []  # List of transitions for each token
        self.model_trained = False
    
    def add_token(self, token):
        """Add a token to the model if it doesn't exist and return its index."""
        if token in self.token_to_idx:
            return self.token_to_idx[token]
        
        token_idx = len(self.tokens)
        self.tokens.append(token)
        self.token_to_idx[token] = token_idx
        self.transitions.append([])
        return token_idx
    
    def add_transition(self, from_token_idx, to_token_idx):
        """Add a transition from one token to another."""
        self.transitions[from_token_idx].append(to_token_idx)
    
    def get_next_token(self, token_idx):
        """Get a random next token based on transitions."""
        if not self.transitions[token_idx]:
            return -1
        
        return random.choice(self.transitions[token_idx])
    
    def find_token(self, search_token):
        """Find a token in the model with various search strategies."""
        # Strategy 1: Direct lookup
        if search_token in self.token_to_idx:
            return self.token_to_idx[search_token]
        
        # Strategy 2: Case insensitive lookup
        search_token_lower = search_token.lower()
        for token, idx in self.token_to_idx.items():
            if token.lower() == search_token_lower:
                return idx
        
        # Strategy 3: Substring match
        for token, idx in self.token_to_idx.items():
            if search_token_lower in token.lower():
                return idx
        
        # Strategy 4: First word match for multi-word tokens
        if ' ' in search_token:
            first_word = search_token.split()[0].lower()
            for token, idx in self.token_to_idx.items():
                if token.lower().startswith(first_word):
                    return idx
        
        # Strategy 5: Return random token with many transitions
        candidates = []
        for idx, transitions in enumerate(self.transitions):
            if len(transitions) > 5:  # Token has reasonable number of transitions
                token = self.tokens[idx]
                # Skip tokens with script artifacts
                if '(' not in token and ')' not in token and ':' not in token:
                    candidates.append((idx, len(transitions)))
        
        if candidates:
            # Sort by transition count (descending) and take top 10
            candidates.sort(key=lambda x: x[1], reverse=True)
            top_candidates = candidates[:10]
            return random.choice(top_candidates)[0]
        
        # Fallback to random token
        return random.randint(0, len(self.tokens) - 1) if self.tokens else -1

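# A minimal usage sketch (illustrative, not part of the trained bot): MarkovModel
# stores a first-order chain as parallel lists -- tokens holds each distinct word,
# token_to_idx maps a word to its index, and transitions[i] records every observed
# successor of token i, so get_next_token() samples uniformly from those successors.
#
#   model = MarkovModel()
#   a = model.add_token("Hello,")
#   b = model.add_token("Newman.")
#   model.add_transition(a, b)
#   model.get_next_token(a)   # -> b (the only observed successor of "Hello,")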

class SeinfeldChatbot:
    """Chatbot that responds with Seinfeld-style dialogue using a Markov chain."""
    
    def __init__(self):
        self.model = MarkovModel()
        self.seinfeld_characters = ["JERRY", "GEORGE", "ELAINE", "KRAMER", "NEWMAN"]
        self.min_response_length = 5  # Minimum words in a response
        self.max_response_attempts = 20  # Maximum attempts to generate a response
        self.stopwords = {
            "a", "an", "the", "and", "but", "or", "for", "nor", "on", "at", "to", "from",
            "by", "with", "in", "out", "over", "under", "again", "further", "then",
            "once", "here", "there", "when", "where", "why", "how", "all", "any",
            "both", "each", "few", "more", "most", "other", "some", "such", "no",
            "not", "only", "own", "same", "so", "than", "too", "very", "s", "t",
            "can", "will", "just", "don", "should", "now", "d", "ll", "m", "o",
            "re", "ve", "y", "ain", "aren", "couldn", "didn", "doesn", "hadn",
            "hasn", "haven", "isn", "ma", "mightn", "mustn", "needn", "shan",
            "shouldn", "wasn", "weren", "won", "wouldn", "i", "me", "my", "myself",
            "we", "our", "ours", "ourselves", "you", "your", "yours", "yourself",
            "yourselves", "he", "him", "his", "himself", "she", "her", "hers",
            "herself", "it", "its", "itself", "they", "them", "their", "theirs"
        }
        self.debug_mode = False
        self.used_seeds = set()  # Track used seed phrases to avoid repetition
    
    def clean_text(self, text):
        """Clean up text by removing script artifacts and extra whitespace."""
        # Remove stage directions
        text = re.sub(r'\([^)]*\)', '', text)
        
        # Remove character names at start of lines (e.g., "JERRY: ")
        text = re.sub(r'^[A-Z]+:', '', text)
        
        # Remove extra whitespace
        text = re.sub(r'\s+', ' ', text).strip()
        
        return text
    
    def tokenize(self, text):
        """Split text into tokens."""
        return text.split()
    
    def remove_punctuation(self, word):
        """Remove punctuation from a word."""
        return word.translate(str.maketrans('', '', string.punctuation))
    
    def train_on_seinfeld_data(self, filename):
        """Train the model on Seinfeld dialogue."""
        print(f"Training on Seinfeld data from {filename}...")
        
        try:
            with open(filename, 'r', encoding='utf-8', errors='ignore') as f:
                content = f.read()
            
            # Break into lines and process line by line
            lines = content.split('\n')
            dialog_lines = []
            
            for line in lines:
                line = line.strip()
                
                # Skip empty lines and scene markers
                if not line or line.startswith('%'):
                    continue
                
                # Check if this is a character's dialogue
                for character in self.seinfeld_characters:
                    if line.startswith(character + ':'):
                        # Extract the dialogue part (after the character name)
                        dialogue = line[len(character) + 1:].strip()
                        dialogue = self.clean_text(dialogue)
                        if dialogue:
                            dialog_lines.append(dialogue)
                        break
            
            # Process the dialogue for training
            token_count = 0

            for line in dialog_lines:
                tokens = self.tokenize(line)
                prev_token_idx = -1  # reset per line so transitions never span two utterances

                
                for token in tokens:
                    token = token.strip()
                    if not token:
                        continue
                    
                    token_idx = self.model.add_token(token)
                    token_count += 1
                    
                    # Add transition from previous token
                    if prev_token_idx >= 0:
                        self.model.add_transition(prev_token_idx, token_idx)
                    
                    prev_token_idx = token_idx
            
            self.model.model_trained = True
            print(f"Training complete! Processed {token_count} tokens from {len(dialog_lines)} dialogue lines.")
            return True
            
        except Exception as e:
            print(f"Error training model: {e}")
            return False
    
    def save_model(self, filename):
        """Save the trained model to a file."""
        if not self.model.model_trained:
            print("No trained model to save!")
            return False
        
        try:
            with open(filename, 'wb') as f:
                pickle.dump(self.model, f)
            print(f"Model saved to {filename}")
            return True
        except Exception as e:
            print(f"Error saving model: {e}")
            return False
    
    def load_model(self, filename):
        """Load a previously trained model from a file."""
        try:
            with open(filename, 'rb') as f:
                self.model = pickle.load(f)
            print(f"Model loaded from {filename}")
            return True
        except Exception as e:
            print(f"Error loading model: {e}")
            return False
    
    def find_good_seeds(self, user_input):
        """Find good seed phrases from user input."""
        words = user_input.split()
        good_seeds = []
        
        # Remove stopwords
        important_words = [w for w in words if w.lower() not in self.stopwords 
                          and len(w) > 2  # Skip very short words
                          and not w.isdigit()]  # Skip numbers
        
        # Try various seed strategies
        
        # 1. Two-word combinations of important words
        if len(important_words) >= 2:
            for i in range(len(important_words) - 1):
                seed = f"{important_words[i]} {important_words[i+1]}"
                token_idx = self.model.find_token(seed)
                if token_idx >= 0 and seed not in self.used_seeds:
                    good_seeds.append((seed, 3))  # Higher score for two-word matches
        
        # 2. Individual important words
        for word in important_words:
            token_idx = self.model.find_token(word)
            if token_idx >= 0 and word not in self.used_seeds:
                good_seeds.append((word, 2))
        
        # 3. Original phrases from input (up to 3 words)
        for i in range(len(words)):
            for j in range(1, min(4, len(words) - i + 1)):
                phrase = " ".join(words[i:i+j])
                token_idx = self.model.find_token(phrase)
                if token_idx >= 0 and phrase not in self.used_seeds:
                    good_seeds.append((phrase, j))  # Score based on length
        
        # Sort by score (descending)
        good_seeds.sort(key=lambda x: x[1], reverse=True)
        
        # Return just the seeds (without scores)
        return [seed for seed, _ in good_seeds]
    
    def generate_sentence(self, seed_phrase):
        """Generate a sentence starting with the given seed phrase."""
        if not seed_phrase or not self.model.model_trained:
            return ""
        
        max_length = 50  # Maximum number of tokens in a sentence
        max_attempts = 30  # Maximum attempts to generate a valid token
        
        # Start with the seed phrase
        token_idx = self.model.find_token(seed_phrase)
        if token_idx < 0:
            return ""
        
        current_token = self.model.tokens[token_idx]
        sentence = current_token
        used_tokens = {current_token}
        word_count = len(seed_phrase.split())
        
        sentence_ended = False
        attempts = 0
        
        while word_count < max_length and attempts < max_attempts and not sentence_ended:
            # Find the next token
            next_token_idx = self.model.get_next_token(token_idx)
            
            if next_token_idx >= 0 and next_token_idx < len(self.model.tokens):
                next_token = self.model.tokens[next_token_idx]
                
                # Check for repetition
                if next_token not in used_tokens:
                    sentence += " " + next_token
                    used_tokens.add(next_token)
                    word_count += 1
                    token_idx = next_token_idx
                    
                    # Check if this ended a sentence
                    if next_token and next_token[-1] in ['.', '!', '?']:
                        sentence_ended = True
                    
                    # Stop once the sentence is reasonably long
                    if word_count >= 8:  # at least 8 words makes a decent sentence
                        break
                else:
                    attempts += 1
            else:
                attempts += 1
        
        # Ensure sentence ends with proper punctuation
        if sentence and sentence[-1] not in ['.', '!', '?']:
            sentence += "."
        
        # Final cleanup
        sentence = sentence.strip()
        sentence = re.sub(r'\s+', ' ', sentence)  # Remove extra spaces
        
        # Filter out sentences that are too short
        if len(sentence.split()) < 5:
            return ""
        
        # Capitalize first letter
        if sentence:
            sentence = sentence[0].upper() + sentence[1:]
        
        return sentence
    
    def generate_fallback_response(self):
        """Generate a fallback Seinfeld-style response when normal generation fails."""
        responses = [
            "What's the deal with that?",
            "Not that there's anything wrong with that.",
            "These pretzels are making me thirsty!",
            "I'm out there, Jerry, and I'm loving every minute of it!",
            "No soup for you!",
            "Serenity now!",
            "I don't wanna be a pirate!",
            "It's not a lie if you believe it.",
            "You know, we're living in a society!",
            "I'm speechless. I have no speech.",
            "You want a piece of me? YOU GOT IT!",
            "Hello, Newman.",
            "That's a shame.",
            "I've yada yada'd sex.",
            "Maybe the dingo ate your baby!",
            "I choose not to run!",
            "It's not you, it's me.",
            "You can stuff your sorries in a sack, mister!",
            "I'm a joke maker. Tell him, Jerry.",
            "And you want to be my latex salesman..."
        ]
        return random.choice(responses)
    
    def generate_response(self, user_input):
        """Generate a response to user input."""
        if not self.model.model_trained:
            return "The chatbot hasn't been trained yet!"
        if not user_input:
            return self.generate_fallback_response()
        
        # Clean and normalize input
        user_input = self.clean_text(user_input)
        
        # Find good seed phrases from the input
        seed_phrases = self.find_good_seeds(user_input)
        
        # Try to generate responses with each seed phrase
        responses = []
        
        for seed in seed_phrases[:5]:  # Try top 5 seeds
            # max(1, ...) guards against zero attempts when there are more seeds than max_response_attempts
            for _ in range(max(1, self.max_response_attempts // max(1, len(seed_phrases)))):
                response = self.generate_sentence(seed)
                if response and len(response.split()) >= self.min_response_length:
                    responses.append(response)
            
            # Mark this seed as used to prevent repetition
            self.used_seeds.add(seed)
            
            # Keep used_seeds from growing too large
            if len(self.used_seeds) > 100:
                self.used_seeds = set(list(self.used_seeds)[-50:])
        
        # If we generated valid responses, choose the best one
        if responses:
            # Prioritize longer responses
            responses.sort(key=lambda x: len(x.split()), reverse=True)
            
            # Randomly select one of the top responses
            return random.choice(responses[:3])
        
        # If all else fails, return a fallback response
        return self.generate_fallback_response()
    
    def get_character_name(self):
        """Get a random Seinfeld character name for the response."""
        top_characters = ["Jerry", "George"]
        return random.choice(top_characters)


def main():
    """Main function to run the chatbot."""
    chatbot = SeinfeldChatbot()
    
    print("=" * 60)
    print("  Seinfeld Markov Chain Chatbot")
    print("=" * 60)
    
    while True:
        print("\nOptions:")
        print("1. Train on Seinfeld dialogue")
        print("2. Load existing model")
        print("3. Chat with the bot")
        print("4. Save the current model")
        print("5. Exit")
        
        choice = input("\nEnter your choice (1-5): ")
        
        if choice == '1':
            filename = input("Enter the Seinfeld dialogue file path (default: seinfeld_complete_dialogue.txt): ")
            if not filename:
                filename = "seinfeld_complete_dialogue.txt"
            
            if not os.path.exists(filename):
                print(f"File {filename} not found!")
                continue
            
            chatbot.train_on_seinfeld_data(filename)
            
        elif choice == '2':
            filename = input("Enter the model file path (default: seinfeld_model.pkl): ")
            if not filename:
                filename = "seinfeld_model.pkl"
            
            if not os.path.exists(filename):
                print(f"File {filename} not found!")
                continue
            
            chatbot.load_model(filename)
            
        elif choice == '3':
            if not chatbot.model.model_trained:
                print("Please train or load a model first!")
                continue
            
            print("\n===== Seinfeld Chat Mode =====")
            print("(Type 'exit' to return to the main menu)")
            
            while True:
                user_input = input("\nYou: ")
                
                if user_input.lower() in ['exit', 'quit', 'bye']:
                    break
                
                character = chatbot.get_character_name()
                response = chatbot.generate_response(user_input)
                print(f"\n{character}: {response}")
                
        elif choice == '4':
            if not chatbot.model.model_trained:
                print("No trained model to save!")
                continue
            
            filename = input("Enter filename to save the model (default: seinfeld_model.pkl): ")
            if not filename:
                filename = "seinfeld_model.pkl"
            
            chatbot.save_model(filename)
            
        elif choice == '5':
            print("Goodbye!")
            break
            
        else:
            print("Invalid choice, please try again.")


if __name__ == "__main__":
    main()

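For anyone who would rather drive the bot from another script than through the menu, here is a minimal sketch. It assumes the code above has been saved as seinfeld_chatbot.py (the module name is an assumption, not part of the script) and that the transcript file from the menu defaults is present.

# Minimal programmatic usage sketch (assumptions noted above).
from seinfeld_chatbot import SeinfeldChatbot

bot = SeinfeldChatbot()
if bot.train_on_seinfeld_data("seinfeld_complete_dialogue.txt"):
    bot.save_model("seinfeld_model.pkl")  # optional: cache the trained chain for menu option 2
    print(bot.generate_response("What's the deal with airline food?"))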