diff --git a/rstat_tool/database.py b/rstat_tool/database.py index 437bd0e..9003385 100644 --- a/rstat_tool/database.py +++ b/rstat_tool/database.py @@ -9,72 +9,6 @@ from datetime import datetime, timedelta, timezone DB_FILE = "reddit_stocks.db" MARKET_CAP_REFRESH_INTERVAL = 86400 -def get_db_connection(): - """Establishes a connection to the SQLite database.""" - conn = sqlite3.connect(DB_FILE) - conn.row_factory = sqlite3.Row - return conn - -def initialize_db(): - """ - Initializes the database and creates the necessary tables if they don't exist. - """ - conn = get_db_connection() - cursor = conn.cursor() - - # --- Create tickers table --- - cursor.execute(""" - CREATE TABLE IF NOT EXISTS tickers ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - symbol TEXT NOT NULL UNIQUE, - market_cap INTEGER, - closing_price REAL, - last_updated INTEGER - ) - """) - - # --- Create subreddits table --- - cursor.execute(""" - CREATE TABLE IF NOT EXISTS subreddits ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - name TEXT NOT NULL UNIQUE - ) - """) - - # --- Create mentions table --- - cursor.execute(""" - CREATE TABLE IF NOT EXISTS mentions ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - ticker_id INTEGER, - subreddit_id INTEGER, - post_id TEXT NOT NULL, - mention_type TEXT NOT NULL, - mention_sentiment REAL, -- Renamed from sentiment_score for clarity - post_avg_sentiment REAL, -- NEW: Stores the avg sentiment of the whole post - mention_timestamp INTEGER NOT NULL, - FOREIGN KEY (ticker_id) REFERENCES tickers (id), - FOREIGN KEY (subreddit_id) REFERENCES subreddits (id) - ) - """) - - cursor.execute(""" - CREATE TABLE IF NOT EXISTS posts ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - post_id TEXT NOT NULL UNIQUE, - title TEXT NOT NULL, - post_url TEXT, - subreddit_id INTEGER, - post_timestamp INTEGER, - comment_count INTEGER, - avg_comment_sentiment REAL, - FOREIGN KEY (subreddit_id) REFERENCES subreddits (id) - ) - """) - - conn.commit() - conn.close() - log.info("Database initialized successfully.") - def clean_stale_tickers(): """ Removes tickers and their associated mentions from the database @@ -381,4 +315,16 @@ def get_all_tickers(): conn = get_db_connection() results = conn.execute("SELECT id, symbol FROM tickers;").fetchall() conn.close() - return results \ No newline at end of file + return results + +def get_ticker_by_symbol(symbol): + """ + Retrieves a single ticker's ID and symbol from the database. + The search is case-insensitive. Returns a Row object or None if not found. + """ + conn = get_db_connection() + cursor = conn.cursor() + cursor.execute("SELECT id, symbol FROM tickers WHERE LOWER(symbol) = LOWER(?)", (symbol,)) + result = cursor.fetchone() + conn.close() + return result \ No newline at end of file diff --git a/rstat_tool/main.py b/rstat_tool/main.py index ed23edd..f4ed819 100644 --- a/rstat_tool/main.py +++ b/rstat_tool/main.py @@ -160,12 +160,19 @@ def main(): """Main function to run the Reddit stock analysis tool.""" parser = argparse.ArgumentParser(description="Analyze stock ticker mentions on Reddit.", formatter_class=argparse.RawTextHelpFormatter) - parser.add_argument("-u", "--update-financials-only", action="store_true", help="Skip Reddit scan and only update financial data for all existing tickers.") - parser.add_argument("-f", "--config", default="subreddits.json", help="Path to the JSON file for scanning.") + parser.add_argument( + "-u", "--update-financials-only", + nargs='?', + const="ALL_TICKERS", # A special value to signify "update all" + default=None, + metavar='TICKER', + help="Update financials. Provide a ticker symbol to update just one,\nor use the flag alone to update all tickers in the database." + ) + parser.add_argument("-f", "--config", default="subreddits.json", help="Path to the JSON file for scanning. (Default: subreddits.json)") parser.add_argument("-s", "--subreddit", help="Scan a single subreddit, ignoring the config file.") - parser.add_argument("-d", "--days", type=int, default=1, help="Number of past days to scan for new posts.") - parser.add_argument("-p", "--posts", type=int, default=200, help="Max posts to check per subreddit.") - parser.add_argument("-c", "--comments", type=int, default=100, help="Number of comments to scan per post.") + parser.add_argument("-d", "--days", type=int, default=1, help="Number of past days to scan for new posts. (Default: 1)") + parser.add_argument("-p", "--posts", type=int, default=200, help="Max posts to check per subreddit. (Default: 200)") + parser.add_argument("-c", "--comments", type=int, default=100, help="Number of comments to scan per post. (Default: 100)") parser.add_argument("--debug", action="store_true", help="Enable detailed debug logging to the console.") parser.add_argument("--stdout", action="store_true", help="Print all log messages to the console.") @@ -174,23 +181,49 @@ def main(): database.initialize_db() - if args.update_financials_only: - log.critical("--- Starting Financial Data Update Only Mode (using isolated fetcher) ---") - all_tickers = database.get_all_tickers() - log.info(f"Found {len(all_tickers)} tickers in the database to update.") + update_mode = args.update_financials_only + + if update_mode: # This block runs if -u or --update-financials-only was used + if update_mode == "ALL_TICKERS": + # This is the "update all" case + log.critical("--- Starting Financial Data Update for ALL tickers ---") + all_tickers = database.get_all_tickers() + log.info(f"Found {len(all_tickers)} tickers in the database to update.") + + conn = database.get_db_connection() + for ticker in all_tickers: + symbol = ticker['symbol'] + log.info(f" -> Updating financials for {symbol}...") + financials = get_financial_data_via_fetcher(symbol) + database.update_ticker_financials( + conn, ticker['id'], + financials.get('market_cap'), + financials.get('closing_price') + ) + conn.close() + log.critical("--- Financial Data Update Complete ---") - conn = database.get_db_connection() - for ticker in all_tickers: - symbol = ticker['symbol'] - log.info(f" -> Updating financials for {symbol}...") - financials = get_financial_data_via_fetcher(symbol) - database.update_ticker_financials( - conn, ticker['id'], - financials.get('market_cap'), - financials.get('closing_price') - ) - conn.close() - log.critical("--- Financial Data Update Complete ---") + else: + # This is the "update single ticker" case + ticker_symbol_to_update = update_mode + log.critical(f"--- Starting Financial Data Update for single ticker: {ticker_symbol_to_update} ---") + + # Find the ticker in the database + ticker_info = database.get_ticker_by_symbol(ticker_symbol_to_update) + + if ticker_info: + conn = database.get_db_connection() + log.info(f" -> Updating financials for {ticker_info['symbol']}...") + financials = get_financial_data_via_fetcher(ticker_info['symbol']) + database.update_ticker_financials( + conn, ticker_info['id'], + financials.get('market_cap'), + financials.get('closing_price') + ) + conn.close() + log.critical("--- Financial Data Update Complete ---") + else: + log.error(f"Ticker '{ticker_symbol_to_update}' not found in the database. Please run a scan first to discover it.") else: log.critical("--- Starting Reddit Scan Mode ---") if args.subreddit: