# database.py
import re
import sqlite3
import time

DB_FILE = "reddit_stocks.db"

# Table/column names cannot be bound as SQL parameters, so any identifier that
# is interpolated into a statement must match this pattern first.
_IDENTIFIER_RE = re.compile(r"[A-Za-z_][A-Za-z0-9_]*")


def get_db_connection(db_file=None):
    """Open a connection to the SQLite database.

    Args:
        db_file: Path of the database file. When omitted, the module-level
            ``DB_FILE`` is used (read at call time, so overriding the module
            attribute still takes effect).

    Returns:
        A ``sqlite3.Connection`` whose rows are ``sqlite3.Row`` objects
        (accessible by column name). The caller must close it.
    """
    conn = sqlite3.connect(DB_FILE if db_file is None else db_file)
    conn.row_factory = sqlite3.Row
    # NOTE(review): SQLite only enforces the FOREIGN KEY clauses declared in
    # initialize_db() when ``PRAGMA foreign_keys = ON`` is issued on each
    # connection; this module does not enable it.
    return conn


def initialize_db(db_file=None):
    """Create the tickers, subreddits and mentions tables if they don't exist.

    Safe to call repeatedly: every statement uses ``CREATE TABLE IF NOT EXISTS``.

    Args:
        db_file: Optional database path; defaults to ``DB_FILE``.
    """
    conn = get_db_connection(db_file)
    try:
        cursor = conn.cursor()

        # --- Create tickers table ---
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS tickers (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                symbol TEXT NOT NULL UNIQUE,
                market_cap INTEGER,
                last_updated INTEGER
            )
        """)

        # --- Create subreddits table ---
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS subreddits (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name TEXT NOT NULL UNIQUE
            )
        """)

        # --- Create mentions table ---
        # UNIQUE(ticker_id, post_id) is what add_mention() relies on to skip
        # duplicate mentions of the same ticker in the same post.
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS mentions (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                ticker_id INTEGER,
                subreddit_id INTEGER,
                post_id TEXT NOT NULL,
                mention_timestamp INTEGER NOT NULL,
                sentiment_score REAL,
                FOREIGN KEY (ticker_id) REFERENCES tickers (id),
                FOREIGN KEY (subreddit_id) REFERENCES subreddits (id),
                UNIQUE(ticker_id, post_id)
            )
        """)

        conn.commit()
    finally:
        # Always release the connection, even if a statement fails.
        conn.close()
    print("Database initialized successfully.")


def get_or_create_entity(conn, table_name, column_name, value):
    """Return the id of the row whose *column_name* equals *value*, creating it if needed.

    Args:
        conn: An open ``sqlite3.Connection``.
        table_name: Table to search/insert into; must be a plain SQL
            identifier because it is interpolated into the statement.
        column_name: Column to match on; same identifier rule as *table_name*.
        value: The value to look up, or to insert when no row matches.

    Returns:
        The integer ``id`` of the existing or newly inserted row.

    Raises:
        ValueError: If *table_name* or *column_name* is not a valid identifier.
    """
    for identifier in (table_name, column_name):
        if not isinstance(identifier, str) or not _IDENTIFIER_RE.fullmatch(identifier):
            raise ValueError(f"Invalid SQL identifier: {identifier!r}")

    cursor = conn.cursor()
    cursor.execute(f"SELECT id FROM {table_name} WHERE {column_name} = ?", (value,))
    row = cursor.fetchone()
    if row is not None:
        # Positional access works whether or not row_factory is sqlite3.Row.
        return row[0]

    cursor.execute(f"INSERT INTO {table_name} ({column_name}) VALUES (?)", (value,))
    conn.commit()
    return cursor.lastrowid


def add_mention(conn, ticker_id, subreddit_id, post_id, timestamp):
    """Insert a mention row, skipping it if it duplicates an existing one.

    A duplicate is a row with the same ``(ticker_id, post_id)`` pair (the
    table's UNIQUE constraint). Other integrity errors, such as a NULL
    ``post_id`` or ``timestamp``, are raised rather than silently ignored.
    Requires SQLite >= 3.24 for the ``ON CONFLICT ... DO NOTHING`` clause.

    Args:
        conn: An open ``sqlite3.Connection``.
        ticker_id: Id of the row in ``tickers``.
        subreddit_id: Id of the row in ``subreddits``.
        post_id: Identifier of the post containing the mention.
        timestamp: Stored in ``mention_timestamp``.

    Returns:
        True if a new row was inserted, False if it was a duplicate.

    Raises:
        sqlite3.IntegrityError: On constraint violations other than the
            duplicate (ticker_id, post_id) conflict.
    """
    cursor = conn.cursor()
    cursor.execute(
        "INSERT INTO mentions (ticker_id, subreddit_id, post_id, mention_timestamp) "
        "VALUES (?, ?, ?, ?) "
        "ON CONFLICT(ticker_id, post_id) DO NOTHING",
        (ticker_id, subreddit_id, post_id, timestamp),
    )
    conn.commit()
    # rowcount is 0 when the conflict clause skipped the insert.
    return cursor.rowcount > 0


def update_ticker_market_cap(conn, ticker_id, market_cap):
    """Set a ticker's market cap and stamp ``last_updated`` with the current time.

    Args:
        conn: An open ``sqlite3.Connection``.
        ticker_id: Id of the row in ``tickers`` to update.
        market_cap: New value for the ``market_cap`` column.
    """
    cursor = conn.cursor()
    # Whole seconds since the Unix epoch.
    current_timestamp = int(time.time())
    cursor.execute(
        "UPDATE tickers SET market_cap = ?, last_updated = ? WHERE id = ?",
        (market_cap, current_timestamp, ticker_id),
    )
    conn.commit()


def get_ticker_info(conn, ticker_id):
    """Return the full ``tickers`` row for *ticker_id*, or None if it does not exist."""
    cursor = conn.cursor()
    cursor.execute("SELECT * FROM tickers WHERE id = ?", (ticker_id,))
    return cursor.fetchone()


def _format_market_cap(market_cap):
    """Render a market cap as a short dollar string (e.g. "$1.23T", "$45.60B", "$123.40M").

    Falsy values (None or 0) render as "N/A"; values below one million are
    shown in full with thousands separators.
    """
    if not market_cap:
        return "N/A"
    for threshold, suffix in ((1e12, "T"), (1e9, "B"), (1e6, "M")):
        if market_cap >= threshold:
            return f"${market_cap / threshold:.2f}{suffix}"
    return f"${market_cap:,}"


def generate_summary_report(db_file=None, limit=20):
    """Print a table of the most-mentioned tickers with their market caps.

    Args:
        db_file: Optional database path; defaults to ``DB_FILE``.
        limit: Maximum number of tickers to list (default 20).
    """
    print("\n--- Summary Report ---")
    # Ties in mention count are broken alphabetically so output is stable.
    query = """
        SELECT
            t.symbol,
            t.market_cap,
            COUNT(m.id) as mention_count
        FROM mentions m
        JOIN tickers t ON m.ticker_id = t.id
        GROUP BY t.symbol, t.market_cap
        ORDER BY mention_count DESC, t.symbol
        LIMIT ?;
    """
    conn = get_db_connection(db_file)
    try:
        results = conn.execute(query, (limit,)).fetchall()
    finally:
        conn.close()

    print(f"{'Ticker':<10} | {'Mentions':<10} | {'Market Cap':<20}")
    print("-" * 45)
    for row in results:
        market_cap_str = _format_market_cap(row['market_cap'])
        print(f"{row['symbol']:<10} | {row['mention_count']:<10} | {market_cap_str:<20}")