From e80978681a3dc3e65f82cb890c4e89bf9b1c5449 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A5l-Kristian=20Hamre?= Date: Mon, 21 Jul 2025 12:35:18 +0200 Subject: [PATCH] Initial database setup. --- database.py | 136 ++++++++++++++++++++++++++++++++++++++++++++ main.py | 107 +++++++++++++++++----------------- ticker_extractor.py | 10 +++- 3 files changed, 200 insertions(+), 53 deletions(-) create mode 100644 database.py diff --git a/database.py b/database.py new file mode 100644 index 0000000..e8a7b15 --- /dev/null +++ b/database.py @@ -0,0 +1,136 @@ +# database.py + +import sqlite3 +import time + +DB_FILE = "reddit_stocks.db" + +def get_db_connection(): + """Establishes a connection to the SQLite database.""" + conn = sqlite3.connect(DB_FILE) + conn.row_factory = sqlite3.Row + return conn + +def initialize_db(): + """Initializes the database and creates tables if they don't exist.""" + conn = get_db_connection() + cursor = conn.cursor() + + # --- Create tickers table --- + cursor.execute(""" + CREATE TABLE IF NOT EXISTS tickers ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + symbol TEXT NOT NULL UNIQUE, + market_cap INTEGER, + last_updated INTEGER + ) + """) + + # --- Create subreddits table --- + cursor.execute(""" + CREATE TABLE IF NOT EXISTS subreddits ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL UNIQUE + ) + """) + + # --- Create mentions table --- + cursor.execute(""" + CREATE TABLE IF NOT EXISTS mentions ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + ticker_id INTEGER, + subreddit_id INTEGER, + post_id TEXT NOT NULL, + mention_timestamp INTEGER NOT NULL, + sentiment_score REAL, + FOREIGN KEY (ticker_id) REFERENCES tickers (id), + FOREIGN KEY (subreddit_id) REFERENCES subreddits (id), + UNIQUE(ticker_id, post_id) + ) + """) + + conn.commit() + conn.close() + print("Database initialized successfully.") + +def get_or_create_entity(conn, table_name, column_name, value): + """Generic function to get or create an entity and return its ID.""" + cursor = conn.cursor() + cursor.execute(f"SELECT id FROM {table_name} WHERE {column_name} = ?", (value,)) + result = cursor.fetchone() + + if result: + return result['id'] + else: + cursor.execute(f"INSERT INTO {table_name} ({column_name}) VALUES (?)", (value,)) + conn.commit() + return cursor.lastrowid + +def add_mention(conn, ticker_id, subreddit_id, post_id, timestamp): + """Adds a new mention to the database, ignoring duplicates.""" + cursor = conn.cursor() + try: + cursor.execute( + "INSERT INTO mentions (ticker_id, subreddit_id, post_id, mention_timestamp) VALUES (?, ?, ?, ?)", + (ticker_id, subreddit_id, post_id, timestamp) + ) + conn.commit() + except sqlite3.IntegrityError: + pass + +def update_ticker_market_cap(conn, ticker_id, market_cap): + """Updates the market cap and timestamp for a specific ticker.""" + cursor = conn.cursor() + current_timestamp = int(time.time()) + cursor.execute( + "UPDATE tickers SET market_cap = ?, last_updated = ? WHERE id = ?", + (market_cap, current_timestamp, ticker_id) + ) + conn.commit() + +def get_ticker_info(conn, ticker_id): + """Retrieves all info for a specific ticker by its ID.""" + cursor = conn.cursor() + cursor.execute("SELECT * FROM tickers WHERE id = ?", (ticker_id,)) + return cursor.fetchone() + +def generate_summary_report(): + """Queries the DB to generate and print a summary with market caps.""" + print("\n--- Summary Report ---") + conn = get_db_connection() + cursor = conn.cursor() + + query = """ + SELECT + t.symbol, + t.market_cap, + COUNT(m.id) as mention_count + FROM mentions m + JOIN tickers t ON m.ticker_id = t.id + GROUP BY t.symbol, t.market_cap + ORDER BY mention_count DESC + LIMIT 20; + """ + + results = cursor.execute(query).fetchall() + + print(f"{'Ticker':<10} | {'Mentions':<10} | {'Market Cap':<20}") + print("-" * 45) + + for row in results: + market_cap_str = "N/A" + if row['market_cap']: + # Format market cap into a readable string (e.g., $1.23T, $45.6B, $123.4M) + mc = row['market_cap'] + if mc >= 1e12: + market_cap_str = f"${mc/1e12:.2f}T" + elif mc >= 1e9: + market_cap_str = f"${mc/1e9:.2f}B" + elif mc >= 1e6: + market_cap_str = f"${mc/1e6:.2f}M" + else: + market_cap_str = f"${mc:,}" + + print(f"{row['symbol']:<10} | {row['mention_count']:<10} | {market_cap_str:<20}") + + conn.close() \ No newline at end of file diff --git a/main.py b/main.py index 0c2dde9..66896bf 100644 --- a/main.py +++ b/main.py @@ -3,19 +3,22 @@ import argparse import json import os -from collections import Counter +import time import praw import yfinance as yf from dotenv import load_dotenv +import database from ticker_extractor import extract_tickers # Load environment variables from .env file load_dotenv() +# How old (in seconds) market cap data can be before we refresh it. 24 hours = 86400 seconds. +MARKET_CAP_REFRESH_INTERVAL = 86400 def load_subreddits(filepath): - """Loads a list of subreddits from a JSON file.""" + # (This function is unchanged) try: with open(filepath, 'r') as f: data = json.load(f) @@ -28,21 +31,17 @@ def load_subreddits(filepath): return None def get_market_cap(ticker_symbol): - """Fetches the market capitalization for a given stock ticker.""" + """Fetches the market capitalization for a given stock ticker from yfinance.""" try: ticker = yf.Ticker(ticker_symbol) - market_cap = ticker.info.get('marketCap') - if market_cap: - # Formatting for better readability - return f"${market_cap:,}" - - return "N/A" - except Exception as e: - # yfinance can sometimes fail for various reasons (e.g., invalid ticker) - return "N/A" + # .info can be slow; .fast_info is a lighter alternative + market_cap = ticker.fast_info.get('marketCap') + return market_cap if market_cap else None + except Exception: + return None def get_reddit_instance(): - """Initializes and returns a PRAW Reddit instance.""" + # (This function is unchanged) client_id = os.getenv("REDDIT_CLIENT_ID") client_secret = os.getenv("REDDIT_CLIENT_SECRET") user_agent = os.getenv("REDDIT_USER_AGENT") @@ -50,41 +49,54 @@ def get_reddit_instance(): if not all([client_id, client_secret, user_agent]): print("Error: Reddit API credentials not found in .env file.") return None + return praw.Reddit(client_id=client_id, client_secret=client_secret, user_agent=user_agent) - return praw.Reddit( - client_id=client_id, - client_secret=client_secret, - user_agent=user_agent - ) def scan_subreddits(reddit, subreddits_list, post_limit=25): - """Scans subreddits for stock tickers and returns a count of each.""" - all_tickers = Counter() - + """Scans subreddits, stores mentions, and updates market caps in the database.""" + conn = database.get_db_connection() + print(f"\nScanning {len(subreddits_list)} subreddits for top {post_limit} posts...") for subreddit_name in subreddits_list: try: + subreddit_id = database.get_or_create_entity(conn, 'subreddits', 'name', subreddit_name) subreddit = reddit.subreddit(subreddit_name) - print(f"r/{subreddit_name}...") - # Fetch hot posts from the subreddit + print(f"Scanning r/{subreddit_name}...") + for submission in subreddit.hot(limit=post_limit): - # Combine title and selftext for analysis full_text = submission.title + " " + submission.selftext - - # Extract tickers from the combined text tickers_in_post = extract_tickers(full_text) - all_tickers.update(tickers_in_post) - - # Future work: also scan comments - # submission.comments.replace_more(limit=0) # Expand all comment trees - # for comment in submission.comments.list(): - # tickers_in_comment = extract_tickers(comment.body) - # all_tickers.update(tickers_in_comment) + + for ticker_symbol in set(tickers_in_post): + ticker_id = database.get_or_create_entity(conn, 'tickers', 'symbol', ticker_symbol) + + database.add_mention( + conn, + ticker_id=ticker_id, + subreddit_id=subreddit_id, + post_id=submission.id, + timestamp=int(submission.created_utc) + ) + + # --- Check if market cap needs updating --- + ticker_info = database.get_ticker_info(conn, ticker_id) + current_time = int(time.time()) + + if not ticker_info['last_updated'] or (current_time - ticker_info['last_updated'] > MARKET_CAP_REFRESH_INTERVAL): + print(f" -> Fetching market cap for {ticker_symbol}...") + market_cap = get_market_cap(ticker_symbol) + if market_cap: + database.update_ticker_market_cap(conn, ticker_id, market_cap) + else: + # If fetch fails, still update the timestamp so we don't try again for 24 hours + database.update_ticker_market_cap(conn, ticker_id, ticker_info['market_cap']) # Keep old value except Exception as e: print(f"Could not scan r/{subreddit_name}. Error: {e}") + + conn.close() + print("\n--- Scan Complete ---") - return all_tickers def main(): """Main function to run the Reddit stock analysis tool.""" @@ -92,28 +104,21 @@ def main(): parser.add_argument("config_file", help="Path to the JSON file containing subreddits.") args = parser.parse_args() - # --- Part 1: Load Configuration & Initialize Reddit --- + # --- Part 1: Initialize --- + database.initialize_db() + subreddits = load_subreddits(args.config_file) - if not subreddits: - return + if not subreddits: return reddit = get_reddit_instance() - if not reddit: - return + if not reddit: return - # --- Part 2: Scan Reddit for Tickers --- - ticker_counts = scan_subreddits(reddit, subreddits) - if not ticker_counts: - print("No tickers found.") - return + # --- Part 2: Scan and Store --- + scan_subreddits(reddit, subreddits) + + # --- Part 3: Generate and Display Report --- + database.generate_summary_report() - print("\n--- Scan Complete ---") - print("Top 15 mentioned tickers:") - - # --- Part 3: Display Results --- - # We will enrich this data with market cap and sentiment in the next steps - for ticker, count in ticker_counts.most_common(15): - print(f"{ticker}: {count} mentions") if __name__ == "__main__": main() \ No newline at end of file diff --git a/ticker_extractor.py b/ticker_extractor.py index 46d1c6f..eafe882 100644 --- a/ticker_extractor.py +++ b/ticker_extractor.py @@ -11,11 +11,17 @@ COMMON_WORDS_BLACKLIST = { "WAY", "WHO", "WHY", "BIG", "BUY", "SELL", "HOLD", "BE", "GO", "ON", "AT", "IN", "IS", "IT", "OF", "OR", "TO", "WE", "UP", "OUT", "SO", "RH", "SEC", "IRS", "USA", "UK", "EU", - "AI", "ML", "AR", "VR", "NFT", "DAO", "WEB3", "ETH", "BTC", + "AI", "ML", "AR", "VR", "NFT", "DAO", "WEB3", "ETH", "BTC", "DOGE", "USD", "EUR", "GBP", "JPY", "CNY", "INR", "AUD", "CAD", "CHF", "RUB", "ZAR", "BRL", "MXN", "HKD", "SGD", "NZD", "RSD", "JPY", "KRW", "SEK", "NOK", "DKK", "PLN", "CZK", "HUF", "TRY", - "US", "IRA", "FDA", "SEC", "FBI", "CIA", "NSA", "NATO", + "US", "IRA", "FDA", "SEC", "FBI", "CIA", "NSA", "NATO", "FINRA", + "NASDAQ", "NYSE", "AMEX", "FTSE", "DAX", "WSB", "SPX", "DJIA", + "EDGAR", "GDP", "CPI", "PPI", "PMI", "ISM", "FOMC", "ECB", "BOE", + "BOJ", "RBA", "RBNZ", "BIS", "NFA", "P", "VOO", "CTB", "DR", + "ETF", "EV", "ESG", "REIT", "SPAC", "IPO", "M&A", "LBO", + "Q1", "Q2", "Q3", "Q4", "FY", "FAQ", "ROI", "ROE", "EPS", "P/E", "PEG", + "FRG", "FXAIX", "FXIAX", "FZROX" } def extract_tickers(text):