diff --git a/database.py b/database.py
index e8a7b15..21c1255 100644
--- a/database.py
+++ b/database.py
@@ -12,11 +12,13 @@ def get_db_connection():
     return conn
 
 def initialize_db():
-    """Initializes the database and creates tables if they don't exist."""
+    """
+    Initializes the database and creates the necessary tables if they don't exist.
+    """
     conn = get_db_connection()
     cursor = conn.cursor()
 
     # --- Create tickers table ---
     cursor.execute("""
         CREATE TABLE IF NOT EXISTS tickers (
             id INTEGER PRIMARY KEY AUTOINCREMENT,
@@ -26,7 +28,7 @@ def initialize_db():
         )
     """)
 
     # --- Create subreddits table ---
     cursor.execute("""
         CREATE TABLE IF NOT EXISTS subreddits (
             id INTEGER PRIMARY KEY AUTOINCREMENT,
@@ -34,7 +36,7 @@ def initialize_db():
         )
     """)
 
-    # --- Create mentions table ---
+    # --- Create mentions table with sentiment_score column ---
     cursor.execute("""
         CREATE TABLE IF NOT EXISTS mentions (
             id INTEGER PRIMARY KEY AUTOINCREMENT,
@@ -53,12 +55,23 @@ def initialize_db():
     conn.close()
     print("Database initialized successfully.")
 
+def add_mention(conn, ticker_id, subreddit_id, post_id, timestamp, sentiment):
+    """Adds a new mention with its sentiment score to the database."""
+    cursor = conn.cursor()
+    try:
+        cursor.execute(
+            "INSERT INTO mentions (ticker_id, subreddit_id, post_id, mention_timestamp, sentiment_score) VALUES (?, ?, ?, ?, ?)",
+            (ticker_id, subreddit_id, post_id, timestamp, sentiment)
+        )
+        conn.commit()
+    except sqlite3.IntegrityError:
+        pass  # Ignore duplicate mentions
+
 def get_or_create_entity(conn, table_name, column_name, value):
     """Generic function to get or create an entity and return its ID."""
     cursor = conn.cursor()
     cursor.execute(f"SELECT id FROM {table_name} WHERE {column_name} = ?", (value,))
     result = cursor.fetchone()
-
     if result:
         return result['id']
     else:
@@ -66,18 +79,6 @@ def get_or_create_entity(conn, table_name, column_name, value):
         conn.commit()
         return cursor.lastrowid
 
-def add_mention(conn, ticker_id, subreddit_id, post_id, timestamp):
-    """Adds a new mention to the database, ignoring duplicates."""
-    cursor = conn.cursor()
-    try:
-        cursor.execute(
-            "INSERT INTO mentions (ticker_id, subreddit_id, post_id, mention_timestamp) VALUES (?, ?, ?, ?)",
-            (ticker_id, subreddit_id, post_id, timestamp)
-        )
-        conn.commit()
-    except sqlite3.IntegrityError:
-        pass
-
 def update_ticker_market_cap(conn, ticker_id, market_cap):
     """Updates the market cap and timestamp for a specific ticker."""
     cursor = conn.cursor()
@@ -95,7 +96,7 @@ def get_ticker_info(conn, ticker_id):
     return cursor.fetchone()
 
 def generate_summary_report():
-    """Queries the DB to generate and print a summary with market caps."""
+    """Queries the DB to generate a summary with market caps and avg. sentiment."""
     print("\n--- Summary Report ---")
     conn = get_db_connection()
     cursor = conn.cursor()
@@ -104,33 +105,38 @@ def generate_summary_report():
         SELECT
             t.symbol,
             t.market_cap,
-            COUNT(m.id) as mention_count
+            COUNT(m.id) as mention_count,
+            AVG(m.sentiment_score) as avg_sentiment
         FROM mentions m
         JOIN tickers t ON m.ticker_id = t.id
         GROUP BY t.symbol, t.market_cap
         ORDER BY mention_count DESC
         LIMIT 20;
     """
-
     results = cursor.execute(query).fetchall()
 
-    print(f"{'Ticker':<10} | {'Mentions':<10} | {'Market Cap':<20}")
-    print("-" * 45)
+    print(f"{'Ticker':<10} | {'Mentions':<10} | {'Sentiment':<18} | {'Market Cap':<20}")
+    print("-" * 65)
 
     for row in results:
+        # Format Market Cap
        market_cap_str = "N/A"
-        if row['market_cap']:
-            # Format market cap into a readable string (e.g., $1.23T, $45.6B, $123.4M)
+        if row['market_cap'] and row['market_cap'] > 0:
             mc = row['market_cap']
-            if mc >= 1e12:
-                market_cap_str = f"${mc/1e12:.2f}T"
-            elif mc >= 1e9:
-                market_cap_str = f"${mc/1e9:.2f}B"
-            elif mc >= 1e6:
-                market_cap_str = f"${mc/1e6:.2f}M"
-            else:
-                market_cap_str = f"${mc:,}"
+            if mc >= 1e12: market_cap_str = f"${mc/1e12:.2f}T"
+            elif mc >= 1e9: market_cap_str = f"${mc/1e9:.2f}B"
+            elif mc >= 1e6: market_cap_str = f"${mc/1e6:.2f}M"
+            else: market_cap_str = f"${mc:,}"
+
+        # Determine Sentiment Label
+        sentiment_score = row['avg_sentiment']
+        if sentiment_score is not None:
+            if sentiment_score > 0.1: sentiment_label = f"Bullish ({sentiment_score:+.2f})"
+            elif sentiment_score < -0.1: sentiment_label = f"Bearish ({sentiment_score:+.2f})"
+            else: sentiment_label = f"Neutral ({sentiment_score:+.2f})"
+        else:
+            sentiment_label = "N/A"
 
-        print(f"{row['symbol']:<10} | {row['mention_count']:<10} | {market_cap_str:<20}")
+        print(f"{row['symbol']:<10} | {row['mention_count']:<10} | {sentiment_label:<18} | {market_cap_str:<20}")
 
     conn.close()
\ No newline at end of file
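A note on the relocated `add_mention`: silently catching `sqlite3.IntegrityError` only deduplicates if the `mentions` table declares a uniqueness constraint, and that part of the schema is cut off by the hunk above. A minimal sketch of the intended behavior, assuming a hypothetical `UNIQUE (ticker_id, post_id)` constraint:

```python
import sqlite3
import time

conn = sqlite3.connect(":memory:")
conn.execute("""
    CREATE TABLE mentions (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        ticker_id INTEGER, subreddit_id INTEGER, post_id TEXT,
        mention_timestamp INTEGER, sentiment_score REAL,
        UNIQUE (ticker_id, post_id)  -- assumed; this is what makes IntegrityError fire
    )
""")

def add_mention(conn, ticker_id, subreddit_id, post_id, timestamp, sentiment):
    try:
        conn.execute(
            "INSERT INTO mentions (ticker_id, subreddit_id, post_id, mention_timestamp, sentiment_score) VALUES (?, ?, ?, ?, ?)",
            (ticker_id, subreddit_id, post_id, timestamp, sentiment),
        )
        conn.commit()
    except sqlite3.IntegrityError:
        pass  # duplicate (ticker, post) pair; skip it

add_mention(conn, 1, 1, "abc123", int(time.time()), 0.42)
add_mention(conn, 1, 1, "abc123", int(time.time()), 0.42)  # ignored as a duplicate
print(conn.execute("SELECT COUNT(*) FROM mentions").fetchone()[0])  # -> 1
```

Without such a constraint the `except` branch is dead code and every rescan re-inserts the same mentions.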
diff --git a/main.py b/main.py
index 66896bf..48b1dd5 100644
--- a/main.py
+++ b/main.py
@@ -11,49 +11,40 @@ from dotenv import load_dotenv
 import database
 from ticker_extractor import extract_tickers
+from sentiment_analyzer import get_sentiment_score  # import the new sentiment module
 
-# Load environment variables from .env file
 load_dotenv()
 
-# How old (in seconds) market cap data can be before we refresh it. 24 hours = 86400 seconds.
 MARKET_CAP_REFRESH_INTERVAL = 86400
 
+# (load_subreddits, get_market_cap, and get_reddit_instance are condensed below; behavior unchanged)
 def load_subreddits(filepath):
-    # (This function is unchanged)
+    # ...
     try:
-        with open(filepath, 'r') as f:
-            data = json.load(f)
-        return data.get("subreddits", [])
-    except FileNotFoundError:
-        print(f"Error: The file '{filepath}' was not found.")
-        return None
-    except json.JSONDecodeError:
-        print(f"Error: Could not decode JSON from '{filepath}'.")
+        with open(filepath, 'r') as f: return json.load(f).get("subreddits", [])
+    except (FileNotFoundError, json.JSONDecodeError) as e:
+        print(f"Error loading config: {e}")
         return None
 
 def get_market_cap(ticker_symbol):
-    """Fetches the market capitalization for a given stock ticker from yfinance."""
+    # ...
     try:
         ticker = yf.Ticker(ticker_symbol)
-        # .info can be slow; .fast_info is a lighter alternative
-        market_cap = ticker.fast_info.get('marketCap')
-        return market_cap if market_cap else None
-    except Exception:
-        return None
+        return ticker.fast_info.get('marketCap')
+    except Exception: return None
 
 def get_reddit_instance():
-    # (This function is unchanged)
+    # ...
     client_id = os.getenv("REDDIT_CLIENT_ID")
     client_secret = os.getenv("REDDIT_CLIENT_SECRET")
     user_agent = os.getenv("REDDIT_USER_AGENT")
-
     if not all([client_id, client_secret, user_agent]):
-        print("Error: Reddit API credentials not found in .env file.")
+        print("Error: Reddit API credentials not found.")
         return None
     return praw.Reddit(client_id=client_id, client_secret=client_secret, user_agent=user_agent)
 
 def scan_subreddits(reddit, subreddits_list, post_limit=25):
-    """Scans subreddits, stores mentions, and updates market caps in the database."""
+    """Scans subreddits, performs sentiment analysis, and stores results in the database."""
     conn = database.get_db_connection()
     print(f"\nScanning {len(subreddits_list)} subreddits for top {post_limit} posts...")
 
@@ -64,32 +55,34 @@ def scan_subreddits(reddit, subreddits_list, post_limit=25):
             print(f"Scanning r/{subreddit_name}...")
 
             for submission in subreddit.hot(limit=post_limit):
-                full_text = submission.title + " " + submission.selftext
-                tickers_in_post = extract_tickers(full_text)
+                # We analyze the title for sentiment, as it's often the most concise summary.
+                # Analyzing all comments could be a future enhancement.
+                text_to_analyze = submission.title
+                tickers_in_post = extract_tickers(text_to_analyze + " " + submission.selftext)
+
+                # --- NEW: Get sentiment score for the post's title ---
+                sentiment = get_sentiment_score(text_to_analyze)
 
                 for ticker_symbol in set(tickers_in_post):
                     ticker_id = database.get_or_create_entity(conn, 'tickers', 'symbol', ticker_symbol)
+                    # --- NEW: Pass the sentiment score to the database ---
                     database.add_mention(
                         conn,
                         ticker_id=ticker_id,
                         subreddit_id=subreddit_id,
                         post_id=submission.id,
-                        timestamp=int(submission.created_utc)
+                        timestamp=int(submission.created_utc),
+                        sentiment=sentiment  # Pass the score here
                     )
 
-                    # --- Check if market cap needs updating ---
+                    # (The market cap update logic remains the same)
                     ticker_info = database.get_ticker_info(conn, ticker_id)
                     current_time = int(time.time())
-
                     if not ticker_info['last_updated'] or (current_time - ticker_info['last_updated'] > MARKET_CAP_REFRESH_INTERVAL):
                         print(f"  -> Fetching market cap for {ticker_symbol}...")
                         market_cap = get_market_cap(ticker_symbol)
-                        if market_cap:
-                            database.update_ticker_market_cap(conn, ticker_id, market_cap)
-                        else:
-                            # If fetch fails, still update the timestamp so we don't try again for 24 hours
-                            database.update_ticker_market_cap(conn, ticker_id, ticker_info['market_cap'])  # Keep old value
+                        database.update_ticker_market_cap(conn, ticker_id, market_cap or ticker_info['market_cap'])
 
         except Exception as e:
             print(f"Could not scan r/{subreddit_name}. Error: {e}")
@@ -97,28 +90,23 @@ def scan_subreddits(reddit, subreddits_list, post_limit=25):
     conn.close()
     print("\n--- Scan Complete ---")
 
-
 def main():
-    """Main function to run the Reddit stock analysis tool."""
+    # --- IMPORTANT: Delete your old DB file before running! ---
+    # Since we changed the schema and logic, old data won't have sentiment scores.
+    # It's best to start fresh: delete the `reddit_stocks.db` file now.
+
     parser = argparse.ArgumentParser(description="Analyze stock ticker mentions on Reddit.")
     parser.add_argument("config_file", help="Path to the JSON file containing subreddits.")
     args = parser.parse_args()
 
-    # --- Part 1: Initialize ---
     database.initialize_db()
-
     subreddits = load_subreddits(args.config_file)
     if not subreddits:
         return
-
     reddit = get_reddit_instance()
     if not reddit:
         return
 
-    # --- Part 2: Scan and Store ---
     scan_subreddits(reddit, subreddits)
-
-    # --- Part 3: Generate and Display Report ---
     database.generate_summary_report()
 
-
 if __name__ == "__main__":
     main()
\ No newline at end of file
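One behavioral detail in `scan_subreddits` worth keeping in mind: when a market-cap fetch fails, `update_ticker_market_cap` is still called with the old value, which re-stamps `last_updated` (per database.py), so each ticker is retried at most once per `MARKET_CAP_REFRESH_INTERVAL`. A self-contained sketch of that staleness check; `needs_refresh` is a hypothetical helper, not a function in the repo:

```python
import time

MARKET_CAP_REFRESH_INTERVAL = 86400  # 24 hours, as in main.py

def needs_refresh(last_updated, now=None):
    """Mirrors the staleness test in scan_subreddits."""
    now = int(time.time()) if now is None else now
    return not last_updated or (now - last_updated > MARKET_CAP_REFRESH_INTERVAL)

assert needs_refresh(None)                       # never fetched -> fetch now
assert needs_refresh(0)                          # falsy timestamp -> fetch now
assert not needs_refresh(int(time.time()))       # fresh -> skip
assert needs_refresh(int(time.time()) - 90_000)  # older than 24h -> fetch again
```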
diff --git a/reddit_stocks.db b/reddit_stocks.db
new file mode 100644
index 0000000..0aa5fac
Binary files /dev/null and b/reddit_stocks.db differ
diff --git a/requirements.txt b/requirements.txt
index 67ac91b..630ba9c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,4 @@
 yfinance
 praw
-python-dotenv
\ No newline at end of file
+python-dotenv
+nltk
\ No newline at end of file
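Since a pre-populated `reddit_stocks.db` now ships with the repo, a quick standard-library way to inspect what it contains (table names per database.py above):

```python
import sqlite3

conn = sqlite3.connect("reddit_stocks.db")
for (name,) in conn.execute("SELECT name FROM sqlite_master WHERE type = 'table'"):
    rows = conn.execute(f"SELECT COUNT(*) FROM {name}").fetchone()[0]
    print(f"{name}: {rows} rows")
conn.close()
```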
diff --git a/sentiment_analyzer.py b/sentiment_analyzer.py
new file mode 100644
index 0000000..32b08e8
--- /dev/null
+++ b/sentiment_analyzer.py
@@ -0,0 +1,19 @@
+# sentiment_analyzer.py
+
+from nltk.sentiment.vader import SentimentIntensityAnalyzer
+
+# Initialize the VADER sentiment intensity analyzer.
+# We only need to create one instance of this.
+_analyzer = SentimentIntensityAnalyzer()
+
+def get_sentiment_score(text):
+    """
+    Analyzes a piece of text and returns its sentiment score.
+
+    The 'compound' score is a single metric that summarizes the sentiment.
+    It ranges from -1 (most negative) to +1 (most positive).
+    """
+    # polarity_scores() returns a dict with 'neg', 'neu', 'pos', and 'compound'
+    # scores; we are most interested in the 'compound' score.
+    scores = _analyzer.polarity_scores(text)
+    return scores['compound']
\ No newline at end of file
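For context: `polarity_scores()` returns `neg`/`neu`/`pos` components plus the normalized `compound` score in [-1, 1], and the ±0.1 bullish/bearish cutoffs in `generate_summary_report` are a stricter variant of the ±0.05 thresholds usually cited for VADER. A quick sanity check of the new module (requires the lexicon installed by setup_nltk.py below; exact scores vary with the lexicon version):

```python
from sentiment_analyzer import get_sentiment_score

# Rough expectations only; VADER scores depend on the lexicon version.
print(get_sentiment_score("Best earnings ever, I love this stock!"))  # clearly positive (> 0.5)
print(get_sentiment_score("Terrible guidance, total disaster."))      # clearly negative (< -0.5)
print(get_sentiment_score("The market opened at 9:30."))              # near zero / neutral
```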
diff --git a/setup_nltk.py b/setup_nltk.py
new file mode 100644
index 0000000..bfd5209
--- /dev/null
+++ b/setup_nltk.py
@@ -0,0 +1,11 @@
+import nltk
+
+# This will download the 'vader_lexicon' dataset.
+# It only needs to be run once.
+try:
+    nltk.data.find('sentiment/vader_lexicon.zip')
+    print("VADER lexicon is already downloaded.")
+except LookupError:
+    print("Downloading VADER lexicon...")
+    nltk.download('vader_lexicon')
+    print("Download complete.")
\ No newline at end of file
diff --git a/ticker_extractor.py b/ticker_extractor.py
index eafe882..7ef8c12 100644
--- a/ticker_extractor.py
+++ b/ticker_extractor.py
@@ -19,9 +19,31 @@ COMMON_WORDS_BLACKLIST = {
     "NASDAQ", "NYSE", "AMEX", "FTSE", "DAX", "WSB", "SPX", "DJIA",
     "EDGAR", "GDP", "CPI", "PPI", "PMI", "ISM", "FOMC", "ECB", "BOE",
     "BOJ", "RBA", "RBNZ", "BIS", "NFA", "P", "VOO", "CTB", "DR",
-    "ETF", "EV", "ESG", "REIT", "SPAC", "IPO", "M&A", "LBO",
+    "ETF", "EV", "ESG", "REIT", "SPAC", "IPO", "M&A", "LBO", "PE",
     "Q1", "Q2", "Q3", "Q4", "FY", "FAQ", "ROI", "ROE", "EPS", "P/E", "PEG",
-    "FRG", "FXAIX", "FXIAX", "FZROX"
+    "FRG", "FXAIX", "FXIAX", "FZROX", "BULL", "BEAR", "BULLISH", "BEARISH",
+    "QQQ", "SPY", "DIA", "IWM", "VTI", "IVV", "SCHB", "SPLG",
+    "ROTH", "IRA", "401K", "403B", "457B", "SEP", "SIMPLE", "HSA",
+    "LONG", "SHORT", "LEVERAGE", "MARGIN", "HEDGE", "SWING", "DAY",
+    "GRAB", "GPU", "MY", "PSA", "AMA", "DM", "OP", "FIHTX",
+    "FINTX", "FINT", "FINTY", "FSPSX", "TOTAL", "LARGE", "MID", "SMALL",
+    "GROWTH", "VALUE", "BLEND", "INCOME", "DIV", "YIELD", "BETA", "ALPHA", "VOLATILITY",
+    "RISK", "RETURN", "SHARPE", "SORTINO", "MAX", "MIN", "STDDEV", "VARIANCE",
+    "PDF", "FULL", "PEAK", "LATE", "EARLY", "MIDDAY", "NIGHT", "MORNING", "AFTERNOON",
+    "CYCLE", "TREND", "PATTERN", "BREAKOUT", "PULLBACK", "REVERSAL", "CONSOLIDATION",
+    "OTC", "TRUE", "FALSE", "NULL", "NONE", "ALL", "ANY", "SOME", "EACH", "EVERY",
+    "STILL", "TERM", "TIME", "DATE", "YEAR", "MONTH", "WEEK", "HOUR", "MINUTE", "SECOND",
+    "JUST", "ALREADY", "YET", "NOW", "LATER", "SOON", "EARLIER", "TODAY", "TOMORROW",
+    "YESTERDAY", "TONIGHT", "THIS", "LAST", "NEXT", "WOULD", "SHOULD", "COULD", "MIGHT",
+    "WILL", "CAN", "MUST", "SHALL", "OUGHT", "TAKE", "MAKE", "HAVE", "GET", "DO", "BE",
+    "GO", "COME", "SEE", "LOOK", "WATCH", "HEAR", "YES", "NO", "OK", "LIKE", "LOVE", "HATE",
+    "WANT", "NEED", "THINK", "BELIEVE", "KNOW", "PRICE", "COST", "WORTH",
+    "EXPENSE", "SPEND", "SAVE", "EARN", "PROFIT", "LOSS", "GAIN", "DEBT", "CREDIT",
+    "BOND", "STOCK", "SHARE", "FUND", "ASSET", "LIABILITY", "BUZZ", "UNDER", "OVER", "BETWEEN",
+    "FRAUD", "SCAM", "REWARD", "INVEST", "TRADE", "BUY", "SELL", "HOLD",
+    "SCALP", "POSITION",
+    "PLAN", "GOAL", "FAST", "HINT", "ABOVE", "BELOW", "AROUND", "NEAR", "FAR",
+    "TL",
 }
 
 def extract_tickers(text):
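The body of `extract_tickers` falls outside this hunk, so the blacklist's consumer isn't shown. The usual shape of such a function (collect $-prefixed or bare uppercase tokens, then filter them through `COMMON_WORDS_BLACKLIST`) would look roughly like this; a sketch under that assumption, not the repo's actual implementation:

```python
import re

# Hypothetical stand-in for the real extract_tickers; blacklist abbreviated.
TICKER_PATTERN = re.compile(r"\$?\b[A-Z]{1,5}\b")
COMMON_WORDS_BLACKLIST = {"ETF", "YOLO", "CEO"}

def extract_tickers(text):
    candidates = (token.lstrip("$") for token in TICKER_PATTERN.findall(text))
    return [c for c in candidates if c not in COMMON_WORDS_BLACKLIST]

print(extract_tickers("YOLO $GME and AAPL, maybe some ETF exposure"))  # ['GME', 'AAPL']
```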