diff --git a/export_image.py b/export_image.py index 29930f6..edde071 100644 --- a/export_image.py +++ b/export_image.py @@ -8,6 +8,7 @@ from playwright.sync_api import sync_playwright # Define the output directory as a constant OUTPUT_DIR = "images" + def export_image(url_path, filename_prefix): """ Launches a headless browser, navigates to a URL path, and screenshots @@ -20,7 +21,7 @@ def export_image(url_path, filename_prefix): base_url = "http://127.0.0.1:5000" url = f"{base_url}/{url_path}" - + # 2. Construct the full output path including the new directory output_file = os.path.join(OUTPUT_DIR, f"{filename_prefix}_{int(time.time())}.png") @@ -28,24 +29,26 @@ def export_image(url_path, filename_prefix): try: browser = p.chromium.launch() page = browser.new_page() - + page.set_viewport_size({"width": 1920, "height": 1080}) - + print(f" Navigating to {url}...") - page.goto(url, wait_until="networkidle") # Wait for network to be idle - + page.goto(url, wait_until="networkidle") # Wait for network to be idle + # Target the specific element we want to screenshot element = page.locator(".image-container") - + print(f" Saving screenshot to {output_file}...") element.screenshot(path=output_file) - + browser.close() print(f"-> Export complete! Image saved to {output_file}") except Exception as e: print(f"\nAn error occurred during export: {e}") - print("Please ensure the 'rstat-dashboard' server is running in another terminal.") + print( + "Please ensure the 'rstat-dashboard' server is running in another terminal." + ) if __name__ == "__main__": @@ -53,9 +56,16 @@ if __name__ == "__main__": parser = argparse.ArgumentParser(description="Export subreddit sentiment images.") group = parser.add_mutually_exclusive_group(required=True) group.add_argument("-s", "--subreddit", help="The name of the subreddit to export.") - group.add_argument("-o", "--overall", action="store_true", help="Export the overall summary image.") - - parser.add_argument("-w", "--weekly", action="store_true", help="Export the weekly view instead of the daily view (only for --subreddit).") + group.add_argument( + "-o", "--overall", action="store_true", help="Export the overall summary image." + ) + + parser.add_argument( + "-w", + "--weekly", + action="store_true", + help="Export the weekly view instead of the daily view (only for --subreddit).", + ) args = parser.parse_args() # Determine the correct URL path and filename based on arguments @@ -65,9 +75,9 @@ if __name__ == "__main__": url_path_to_render = f"subreddit/{args.subreddit}?view={view_type}&image=true" filename_prefix_to_save = f"{args.subreddit}_{view_type}" export_image(url_path_to_render, filename_prefix_to_save) - + elif args.overall: # For overall, we assume daily view for the image url_path_to_render = "/?view=daily&image=true" filename_prefix_to_save = "overall_summary_daily" - export_image(url_path_to_render, filename_prefix_to_save) \ No newline at end of file + export_image(url_path_to_render, filename_prefix_to_save) diff --git a/fetch_close_price.py b/fetch_close_price.py index 5d0be19..d52ecf7 100644 --- a/fetch_close_price.py +++ b/fetch_close_price.py @@ -13,26 +13,25 @@ if __name__ == "__main__": if len(sys.argv) < 2: # Exit with an error code if no ticker is provided sys.exit(1) - + ticker_symbol = sys.argv[1] - + try: # Instead of the global yf.download(), we use the Ticker object's .history() method. # This uses a different internal code path that we have proven is stable. 
ticker = yf.Ticker(ticker_symbol) data = ticker.history(period="2d", auto_adjust=False) - # --- END OF FIX --- - + closing_price = None if not data.empty: - last_close_raw = data['Close'].iloc[-1] + last_close_raw = data["Close"].iloc[-1] if pd.notna(last_close_raw): closing_price = float(last_close_raw) - + # On success, print JSON to stdout and exit cleanly print(json.dumps({"closing_price": closing_price})) sys.exit(0) except Exception: # If any error occurs, print an empty JSON and exit with an error code print(json.dumps({"closing_price": None})) - sys.exit(1) \ No newline at end of file + sys.exit(1) diff --git a/fetch_market_cap.py b/fetch_market_cap.py index 9c2a557..87b415a 100644 --- a/fetch_market_cap.py +++ b/fetch_market_cap.py @@ -12,17 +12,17 @@ if __name__ == "__main__": if len(sys.argv) < 2: # Exit with an error code if no ticker is provided sys.exit(1) - + ticker_symbol = sys.argv[1] - + try: # Directly get the market cap - market_cap = yf.Ticker(ticker_symbol).info.get('marketCap') - + market_cap = yf.Ticker(ticker_symbol).info.get("marketCap") + # On success, print JSON to stdout and exit cleanly print(json.dumps({"market_cap": market_cap})) sys.exit(0) except Exception: # If any error occurs, print an empty JSON and exit with an error code print(json.dumps({"market_cap": None})) - sys.exit(1) \ No newline at end of file + sys.exit(1) diff --git a/get_refresh_token.py b/get_refresh_token.py index 66bd12e..6f7f768 100644 --- a/get_refresh_token.py +++ b/get_refresh_token.py @@ -8,7 +8,8 @@ import random import socket # --- IMPORTANT: Ensure this matches the "redirect uri" in your Reddit App settings --- -REDIRECT_URI = "http://localhost:5000" +REDIRECT_URI = "http://localhost:5000" + def main(): print("--- RSTAT Refresh Token Generator ---") @@ -17,7 +18,9 @@ def main(): client_secret = os.getenv("REDDIT_CLIENT_SECRET") if not all([client_id, client_secret]): - print("Error: REDDIT_CLIENT_ID and REDDIT_CLIENT_SECRET must be set in your .env file.") + print( + "Error: REDDIT_CLIENT_ID and REDDIT_CLIENT_SECRET must be set in your .env file." + ) return # 1. Initialize PRAW @@ -25,44 +28,55 @@ def main(): client_id=client_id, client_secret=client_secret, redirect_uri=REDIRECT_URI, - user_agent="rstat_token_fetcher (by u/YourUsername)" # Can be anything + user_agent="rstat_token_fetcher (by u/YourUsername)", # Can be anything ) # 2. Generate the authorization URL # Scopes define what our script is allowed to do. 'identity' and 'submit' are needed. - scopes = ["identity", "submit", "read"] + scopes = ["identity", "submit", "read"] state = str(random.randint(0, 65536)) auth_url = reddit.auth.url(scopes, state, "permanent") print("\nStep 1: Open this URL in your browser:\n") print(auth_url) - - print("\nStep 2: Log in to Reddit, click 'Allow', and you'll be redirected to a 'page not found'.") - print("Step 3: Copy the ENTIRE URL from your browser's address bar after the redirect.") - + + print( + "\nStep 2: Log in to Reddit, click 'Allow', and you'll be redirected to a 'page not found'." + ) + print( + "Step 3: Copy the ENTIRE URL from your browser's address bar after the redirect." + ) + # 3. Get the redirected URL from the user - redirected_url = input("\nStep 4: Paste the full redirected URL here and press Enter:\n> ") + redirected_url = input( + "\nStep 4: Paste the full redirected URL here and press Enter:\n> " + ) # 4. 
Exchange the authorization code for a refresh token try: # The state is used to prevent CSRF attacks, we're just checking it matches assert state == redirected_url.split("state=")[1].split("&")[0] code = redirected_url.split("code=")[1].split("#_")[0] - + print("\nAuthorization code received. Fetching refresh token...") - + # This is the line that gets the key! refresh_token = reddit.auth.authorize(code) - + print("\n--- SUCCESS! ---") print("Your Refresh Token is:\n") print(refresh_token) - print("\nStep 5: Copy this token and add it to your .env file as REDDIT_REFRESH_TOKEN.") - print("Step 6: You can now delete your REDDIT_USERNAME and REDDIT_PASSWORD from the .env file.") + print( + "\nStep 5: Copy this token and add it to your .env file as REDDIT_REFRESH_TOKEN." + ) + print( + "Step 6: You can now delete your REDDIT_USERNAME and REDDIT_PASSWORD from the .env file." + ) except Exception as e: print(f"\nAn error occurred: {e}") print("Please make sure you copied the full URL.") + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/post_to_reddit.py b/post_to_reddit.py index be60a57..ee69cf9 100644 --- a/post_to_reddit.py +++ b/post_to_reddit.py @@ -11,28 +11,32 @@ from pathlib import Path # --- CONFIGURATION --- IMAGE_DIR = "images" + def get_reddit_instance(): """Initializes and returns a PRAW Reddit instance using OAuth2 refresh token.""" - env_path = Path(__file__).parent / '.env' + env_path = Path(__file__).parent / ".env" load_dotenv(dotenv_path=env_path) - + client_id = os.getenv("REDDIT_CLIENT_ID") client_secret = os.getenv("REDDIT_CLIENT_SECRET") user_agent = os.getenv("REDDIT_USER_AGENT") refresh_token = os.getenv("REDDIT_REFRESH_TOKEN") if not all([client_id, client_secret, user_agent, refresh_token]): - print("Error: Reddit API credentials (including REDDIT_REFRESH_TOKEN) must be set in .env file.") + print( + "Error: Reddit API credentials (including REDDIT_REFRESH_TOKEN) must be set in .env file." + ) return None - + return praw.Reddit( client_id=client_id, client_secret=client_secret, user_agent=user_agent, - refresh_token=refresh_token + refresh_token=refresh_token, ) + def find_latest_image(pattern): """Finds the most recent file in the IMAGE_DIR that matches a given pattern.""" try: @@ -47,17 +51,34 @@ def find_latest_image(pattern): print(f"Error finding image file: {e}") return None + def main(): """Main function to find an image and post it to Reddit.""" - parser = argparse.ArgumentParser(description="Find the latest sentiment image and post it to a subreddit.") - parser.add_argument("-s", "--subreddit", help="The source subreddit of the image to post. (Defaults to overall summary)") - parser.add_argument("-w", "--weekly", action="store_true", help="Post the weekly summary instead of the daily one.") - parser.add_argument("-t", "--target-subreddit", default="rstat", help="The subreddit to post the image to. (Default: rstat)") + parser = argparse.ArgumentParser( + description="Find the latest sentiment image and post it to a subreddit." + ) + parser.add_argument( + "-s", + "--subreddit", + help="The source subreddit of the image to post. (Defaults to overall summary)", + ) + parser.add_argument( + "-w", + "--weekly", + action="store_true", + help="Post the weekly summary instead of the daily one.", + ) + parser.add_argument( + "-t", + "--target-subreddit", + default="rstat", + help="The subreddit to post the image to. (Default: rstat)", + ) args = parser.parse_args() # --- 1. 
Determine filename pattern and post title --- current_date_str = datetime.now(timezone.utc).strftime("%Y-%m-%d") - + if args.subreddit: view_type = "weekly" if args.weekly else "daily" filename_pattern = f"{args.subreddit.lower()}_{view_type}_*.png" @@ -65,36 +86,42 @@ def main(): else: # Default to the overall summary if args.weekly: - print("Warning: --weekly flag has no effect for overall summary. Posting overall daily image.") + print( + "Warning: --weekly flag has no effect for overall summary. Posting overall daily image." + ) filename_pattern = "overall_summary_*.png" - post_title = f"Overall Top 10 Ticker Mentions Across Reddit ({current_date_str})" - + post_title = ( + f"Overall Top 10 Ticker Mentions Across Reddit ({current_date_str})" + ) + print(f"Searching for image pattern: {filename_pattern}") - + # --- 2. Find the latest image file --- image_to_post = find_latest_image(filename_pattern) - + if not image_to_post: - print(f"Error: No image found matching the pattern '{filename_pattern}'. Please run the scraper and exporter first.") + print( + f"Error: No image found matching the pattern '{filename_pattern}'. Please run the scraper and exporter first." + ) return print(f"Found image: {image_to_post}") - + # --- 3. Connect to Reddit and submit --- reddit = get_reddit_instance() if not reddit: return - + try: target_sub = reddit.subreddit(args.target_subreddit) print(f"Submitting '{post_title}' to r/{target_sub.display_name}...") - + submission = target_sub.submit_image( title=post_title, image_path=image_to_post, - flair_id=None # Optional: You can add a flair ID here if you want + flair_id=None, # Optional: You can add a flair ID here if you want ) - + print("\n--- Post Successful! ---") print(f"Post URL: {submission.shortlink}") @@ -103,4 +130,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/rstat_tool/cleanup.py b/rstat_tool/cleanup.py index ddd3832..dee1232 100644 --- a/rstat_tool/cleanup.py +++ b/rstat_tool/cleanup.py @@ -3,45 +3,56 @@ import argparse from . import database from .logger_setup import setup_logging, logger as log + # We can't reuse load_subreddits from main anymore if it's not in the same file # So we will duplicate it here. It's small and keeps this script self-contained. import json + def load_subreddits(filepath): """Loads a list of subreddits from a JSON file.""" try: - with open(filepath, 'r') as f: + with open(filepath, "r") as f: data = json.load(f) return data.get("subreddits", []) except (FileNotFoundError, json.JSONDecodeError) as e: log.error(f"Error loading config file '{filepath}': {e}") return None + def run_cleanup(): """Main function for the cleanup tool.""" parser = argparse.ArgumentParser( description="A tool to clean stale data from the RSTAT database.", - formatter_class=argparse.RawTextHelpFormatter + formatter_class=argparse.RawTextHelpFormatter, ) - parser.add_argument("--tickers", action="store_true", help="Clean tickers that are in the blacklist.") - + parser.add_argument( + "--tickers", + action="store_true", + help="Clean tickers that are in the blacklist.", + ) + # --- UPDATED ARGUMENT DEFINITION --- # nargs='?': Makes the argument optional. # const='subreddits.json': The value used if the flag is present with no argument. # default=None: The value if the flag is not present at all. 
parser.add_argument( - "--subreddits", - nargs='?', - const='subreddits.json', + "--subreddits", + nargs="?", + const="subreddits.json", default=None, - help="Clean data from subreddits NOT in the specified config file.\n(Defaults to 'subreddits.json' if flag is used without a value)." + help="Clean data from subreddits NOT in the specified config file.\n(Defaults to 'subreddits.json' if flag is used without a value).", ) - - parser.add_argument("--all", action="store_true", help="Run all available cleanup tasks.") - parser.add_argument("--stdout", action="store_true", help="Print all log messages to the console.") - + + parser.add_argument( + "--all", action="store_true", help="Run all available cleanup tasks." + ) + parser.add_argument( + "--stdout", action="store_true", help="Print all log messages to the console." + ) + args = parser.parse_args() - + setup_logging(console_verbose=args.stdout) run_any_task = False @@ -57,7 +68,7 @@ def run_cleanup(): if args.all or args.subreddits is not None: run_any_task = True # If --all is used, default to 'subreddits.json' if --subreddits wasn't also specified - config_file = args.subreddits or 'subreddits.json' + config_file = args.subreddits or "subreddits.json" log.info(f"\nCleaning subreddits based on active list in: {config_file}") active_subreddits = load_subreddits(config_file) if active_subreddits is not None: @@ -65,10 +76,13 @@ def run_cleanup(): if not run_any_task: parser.print_help() - log.error("\nError: Please provide at least one cleanup option (e.g., --tickers, --subreddits, --all).") + log.error( + "\nError: Please provide at least one cleanup option (e.g., --tickers, --subreddits, --all)." + ) return log.critical("\nCleanup finished.") + if __name__ == "__main__": - run_cleanup() \ No newline at end of file + run_cleanup() diff --git a/rstat_tool/dashboard.py b/rstat_tool/dashboard.py index 89e5974..ce15578 100644 --- a/rstat_tool/dashboard.py +++ b/rstat_tool/dashboard.py @@ -8,13 +8,14 @@ from .database import ( get_deep_dive_details, get_daily_summary_for_subreddit, get_weekly_summary_for_subreddit, - get_overall_daily_summary, # Now correctly imported - get_overall_weekly_summary # Now correctly imported + get_overall_daily_summary, # Now correctly imported + get_overall_weekly_summary, # Now correctly imported ) -app = Flask(__name__, template_folder='../templates') +app = Flask(__name__, template_folder="../templates") -@app.template_filter('format_mc') + +@app.template_filter("format_mc") def format_market_cap(mc): """Formats a large number into a readable market cap string.""" if mc is None or mc == 0: @@ -28,26 +29,28 @@ def format_market_cap(mc): else: return f"${mc:,}" + @app.context_processor def inject_subreddits(): """Makes the list of all subreddits available to every template for the navbar.""" return dict(all_subreddits=get_all_scanned_subreddits()) + @app.route("/") def overall_dashboard(): """Handler for the main, overall dashboard.""" - view_type = request.args.get('view', 'daily') - is_image_mode = request.args.get('image') == 'true' - - if view_type == 'weekly': + view_type = request.args.get("view", "daily") + is_image_mode = request.args.get("image") == "true" + + if view_type == "weekly": tickers, start, end = get_overall_weekly_summary() date_string = f"{start.strftime('%b %d')} - {end.strftime('%b %d, %Y')}" subtitle = "All Subreddits - Top 10 Weekly" - else: # Default to daily + else: # Default to daily tickers = get_overall_daily_summary() date_string = datetime.now(timezone.utc).strftime("%Y-%m-%d") 
subtitle = "All Subreddits - Top 10 Daily" - + return render_template( "dashboard_view.html", title="Overall Dashboard", @@ -57,26 +60,27 @@ def overall_dashboard(): view_type=view_type, subreddit_name=None, is_image_mode=is_image_mode, - base_url="/" + base_url="/", ) + @app.route("/subreddit/") def subreddit_dashboard(name): """Handler for per-subreddit dashboards.""" - view_type = request.args.get('view', 'daily') - is_image_mode = request.args.get('image') == 'true' - - if view_type == 'weekly': + view_type = request.args.get("view", "daily") + is_image_mode = request.args.get("image") == "true" + + if view_type == "weekly": today = datetime.now(timezone.utc) target_date = today - timedelta(days=7) tickers, start, end = get_weekly_summary_for_subreddit(name, target_date) date_string = f"{start.strftime('%b %d')} - {end.strftime('%b %d, %Y')}" subtitle = f"r/{name} - Top 10 Weekly" - else: # Default to daily + else: # Default to daily tickers = get_daily_summary_for_subreddit(name) date_string = datetime.now(timezone.utc).strftime("%Y-%m-%d") subtitle = f"r/{name} - Top 10 Daily" - + return render_template( "dashboard_view.html", title=f"r/{name} Dashboard", @@ -86,9 +90,10 @@ def subreddit_dashboard(name): view_type=view_type, subreddit_name=name, is_image_mode=is_image_mode, - base_url=f"/subreddit/{name}" + base_url=f"/subreddit/{name}", ) + @app.route("/deep-dive/") def deep_dive(symbol): """The handler for the deep-dive page for a specific ticker.""" @@ -96,6 +101,7 @@ def deep_dive(symbol): posts = get_deep_dive_details(symbol) return render_template("deep_dive.html", posts=posts, symbol=symbol) + def start_dashboard(): """The main function called by the 'rstat-dashboard' command.""" log.info("Starting Flask server...") @@ -103,5 +109,6 @@ def start_dashboard(): log.info("Press CTRL+C to stop the server.") app.run(debug=True) + if __name__ == "__main__": - start_dashboard() \ No newline at end of file + start_dashboard() diff --git a/rstat_tool/database.py b/rstat_tool/database.py index 158a744..bd4d126 100644 --- a/rstat_tool/database.py +++ b/rstat_tool/database.py @@ -9,6 +9,7 @@ from datetime import datetime, timedelta, timezone DB_FILE = "reddit_stocks.db" MARKET_CAP_REFRESH_INTERVAL = 86400 + def clean_stale_tickers(): """ Removes tickers and their associated mentions from the database @@ -18,9 +19,9 @@ def clean_stale_tickers(): conn = get_db_connection() cursor = conn.cursor() - placeholders = ','.join('?' for _ in COMMON_WORDS_BLACKLIST) + placeholders = ",".join("?" for _ in COMMON_WORDS_BLACKLIST) query = f"SELECT id, symbol FROM tickers WHERE symbol IN ({placeholders})" - + cursor.execute(query, tuple(COMMON_WORDS_BLACKLIST)) stale_tickers = cursor.fetchall() @@ -30,17 +31,18 @@ def clean_stale_tickers(): return for ticker in stale_tickers: - ticker_id = ticker['id'] - ticker_symbol = ticker['symbol'] + ticker_id = ticker["id"] + ticker_symbol = ticker["symbol"] log.info(f"Removing stale ticker '{ticker_symbol}' (ID: {ticker_id})...") cursor.execute("DELETE FROM mentions WHERE ticker_id = ?", (ticker_id,)) cursor.execute("DELETE FROM tickers WHERE id = ?", (ticker_id,)) - + deleted_count = conn.total_changes conn.commit() conn.close() log.info(f"Cleanup complete. Removed {deleted_count} records.") + def clean_stale_subreddits(active_subreddits): """ Removes all data associated with subreddits that are NOT in the active list. 
@@ -57,9 +59,9 @@ def clean_stale_subreddits(active_subreddits): db_subreddits = cursor.fetchall() stale_sub_ids = [] for sub in db_subreddits: - if sub['name'] not in active_subreddits_lower: + if sub["name"] not in active_subreddits_lower: log.info(f"Found stale subreddit to remove: r/{sub['name']}") - stale_sub_ids.append(sub['id']) + stale_sub_ids.append(sub["id"]) if not stale_sub_ids: log.info("No stale subreddits to clean.") conn.close() @@ -73,15 +75,18 @@ def clean_stale_subreddits(active_subreddits): conn.close() log.info("Stale subreddit cleanup complete.") + def get_db_connection(): conn = sqlite3.connect(DB_FILE) conn.row_factory = sqlite3.Row return conn + def initialize_db(): conn = get_db_connection() cursor = conn.cursor() - cursor.execute(""" + cursor.execute( + """ CREATE TABLE IF NOT EXISTS tickers ( id INTEGER PRIMARY KEY AUTOINCREMENT, symbol TEXT NOT NULL UNIQUE, @@ -89,14 +94,18 @@ def initialize_db(): closing_price REAL, last_updated INTEGER ) - """) - cursor.execute(""" + """ + ) + cursor.execute( + """ CREATE TABLE IF NOT EXISTS subreddits ( id INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT NOT NULL UNIQUE ) - """) - cursor.execute(""" + """ + ) + cursor.execute( + """ CREATE TABLE IF NOT EXISTS mentions ( id INTEGER PRIMARY KEY AUTOINCREMENT, ticker_id INTEGER, @@ -109,8 +118,10 @@ def initialize_db(): FOREIGN KEY (ticker_id) REFERENCES tickers (id), FOREIGN KEY (subreddit_id) REFERENCES subreddits (id) ) - """) - cursor.execute(""" + """ + ) + cursor.execute( + """ CREATE TABLE IF NOT EXISTS posts ( id INTEGER PRIMARY KEY AUTOINCREMENT, post_id TEXT NOT NULL UNIQUE, @@ -122,12 +133,23 @@ def initialize_db(): avg_comment_sentiment REAL, FOREIGN KEY (subreddit_id) REFERENCES subreddits (id) ) - """) + """ + ) conn.commit() conn.close() log.info("Database initialized successfully.") -def add_mention(conn, ticker_id, subreddit_id, post_id, mention_type, timestamp, mention_sentiment, post_avg_sentiment=None): + +def add_mention( + conn, + ticker_id, + subreddit_id, + post_id, + mention_type, + timestamp, + mention_sentiment, + post_avg_sentiment=None, +): cursor = conn.cursor() try: cursor.execute( @@ -135,40 +157,52 @@ def add_mention(conn, ticker_id, subreddit_id, post_id, mention_type, timestamp, INSERT INTO mentions (ticker_id, subreddit_id, post_id, mention_type, mention_timestamp, mention_sentiment, post_avg_sentiment) VALUES (?, ?, ?, ?, ?, ?, ?) """, - (ticker_id, subreddit_id, post_id, mention_type, timestamp, mention_sentiment, post_avg_sentiment) + ( + ticker_id, + subreddit_id, + post_id, + mention_type, + timestamp, + mention_sentiment, + post_avg_sentiment, + ), ) conn.commit() except sqlite3.IntegrityError: pass + def get_or_create_entity(conn, table_name, column_name, value): """Generic function to get or create an entity and return its ID.""" cursor = conn.cursor() cursor.execute(f"SELECT id FROM {table_name} WHERE {column_name} = ?", (value,)) result = cursor.fetchone() if result: - return result['id'] + return result["id"] else: cursor.execute(f"INSERT INTO {table_name} ({column_name}) VALUES (?)", (value,)) conn.commit() return cursor.lastrowid + def update_ticker_financials(conn, ticker_id, market_cap, closing_price): """Updates the financials and timestamp for a specific ticker.""" cursor = conn.cursor() current_timestamp = int(time.time()) cursor.execute( "UPDATE tickers SET market_cap = ?, closing_price = ?, last_updated = ? 
WHERE id = ?", - (market_cap, closing_price, current_timestamp, ticker_id) + (market_cap, closing_price, current_timestamp, ticker_id), ) conn.commit() + def get_ticker_info(conn, ticker_id): """Retrieves all info for a specific ticker by its ID.""" cursor = conn.cursor() cursor.execute("SELECT * FROM tickers WHERE id = ?", (ticker_id,)) return cursor.fetchone() + def get_week_start_end(for_date): """ Calculates the start (Monday, 00:00:00) and end (Sunday, 23:59:59) @@ -178,13 +212,14 @@ def get_week_start_end(for_date): # Monday is 0, Sunday is 6 start_of_week = for_date - timedelta(days=for_date.weekday()) end_of_week = start_of_week + timedelta(days=6) - + # Set time to the very beginning and very end of the day for an inclusive range start_of_week = start_of_week.replace(hour=0, minute=0, second=0, microsecond=0) end_of_week = end_of_week.replace(hour=23, minute=59, second=59, microsecond=999999) - + return start_of_week, end_of_week + def add_or_update_post_analysis(conn, post_data): """ Inserts a new post analysis record or updates an existing one. @@ -200,10 +235,11 @@ def add_or_update_post_analysis(conn, post_data): comment_count = excluded.comment_count, avg_comment_sentiment = excluded.avg_comment_sentiment; """, - post_data + post_data, ) conn.commit() + def get_overall_summary(limit=10): """ Gets the top tickers across all subreddits from the LAST 24 HOURS. @@ -211,7 +247,7 @@ def get_overall_summary(limit=10): conn = get_db_connection() one_day_ago = datetime.now(timezone.utc) - timedelta(days=1) one_day_ago_timestamp = int(one_day_ago.timestamp()) - + query = """ SELECT t.symbol, t.market_cap, t.closing_price, COUNT(m.id) as mention_count, SUM(CASE WHEN m.mention_sentiment > 0.1 THEN 1 ELSE 0 END) as bullish_mentions, @@ -226,6 +262,7 @@ def get_overall_summary(limit=10): conn.close() return results + def get_subreddit_summary(subreddit_name, limit=10): """ Gets the top tickers for a specific subreddit from the LAST 24 HOURS. @@ -233,7 +270,7 @@ def get_subreddit_summary(subreddit_name, limit=10): conn = get_db_connection() one_day_ago = datetime.now(timezone.utc) - timedelta(days=1) one_day_ago_timestamp = int(one_day_ago.timestamp()) - + query = """ SELECT t.symbol, t.market_cap, t.closing_price, COUNT(m.id) as mention_count, SUM(CASE WHEN m.mention_sentiment > 0.1 THEN 1 ELSE 0 END) as bullish_mentions, @@ -244,12 +281,15 @@ def get_subreddit_summary(subreddit_name, limit=10): GROUP BY t.symbol, t.market_cap, t.closing_price ORDER BY mention_count DESC LIMIT ?; """ - results = conn.execute(query, (subreddit_name, one_day_ago_timestamp, limit)).fetchall() + results = conn.execute( + query, (subreddit_name, one_day_ago_timestamp, limit) + ).fetchall() conn.close() return results + def get_daily_summary_for_subreddit(subreddit_name): - """ Gets a summary for the DAILY image view (last 24 hours). """ + """Gets a summary for the DAILY image view (last 24 hours).""" conn = get_db_connection() one_day_ago = datetime.now(timezone.utc) - timedelta(days=1) one_day_ago_timestamp = int(one_day_ago.timestamp()) @@ -268,8 +308,9 @@ def get_daily_summary_for_subreddit(subreddit_name): conn.close() return results + def get_weekly_summary_for_subreddit(subreddit_name, for_date): - """ Gets a summary for the WEEKLY image view (full week). 
""" + """Gets a summary for the WEEKLY image view (full week).""" conn = get_db_connection() start_of_week, end_of_week = get_week_start_end(for_date) start_timestamp = int(start_of_week.timestamp()) @@ -285,10 +326,13 @@ def get_weekly_summary_for_subreddit(subreddit_name, for_date): GROUP BY t.symbol, t.market_cap, t.closing_price ORDER BY total_mentions DESC LIMIT 10; """ - results = conn.execute(query, (subreddit_name, start_timestamp, end_timestamp)).fetchall() + results = conn.execute( + query, (subreddit_name, start_timestamp, end_timestamp) + ).fetchall() conn.close() return results, start_of_week, end_of_week + def get_overall_image_view_summary(): """ Gets a summary of top tickers across ALL subreddits for the DAILY image view (last 24 hours). @@ -311,6 +355,7 @@ def get_overall_image_view_summary(): conn.close() return results + def get_overall_daily_summary(): """ Gets the top tickers across all subreddits from the LAST 24 HOURS. @@ -332,13 +377,16 @@ def get_overall_daily_summary(): conn.close() return results + def get_overall_weekly_summary(): """ Gets the top tickers across all subreddits for the LAST 7 DAYS. """ conn = get_db_connection() today = datetime.now(timezone.utc) - start_of_week, end_of_week = get_week_start_end(today - timedelta(days=7)) # Get last week's boundaries + start_of_week, end_of_week = get_week_start_end( + today - timedelta(days=7) + ) # Get last week's boundaries start_timestamp = int(start_of_week.timestamp()) end_timestamp = int(end_of_week.timestamp()) query = """ @@ -354,8 +402,9 @@ def get_overall_weekly_summary(): conn.close() return results, start_of_week, end_of_week + def get_deep_dive_details(ticker_symbol): - """ Gets all analyzed posts that mention a specific ticker. """ + """Gets all analyzed posts that mention a specific ticker.""" conn = get_db_connection() query = """ SELECT DISTINCT p.*, s.name as subreddit_name FROM posts p @@ -367,12 +416,16 @@ def get_deep_dive_details(ticker_symbol): conn.close() return results + def get_all_scanned_subreddits(): - """ Gets a unique list of all subreddits we have data for. """ + """Gets a unique list of all subreddits we have data for.""" conn = get_db_connection() - results = conn.execute("SELECT DISTINCT name FROM subreddits ORDER BY name ASC;").fetchall() + results = conn.execute( + "SELECT DISTINCT name FROM subreddits ORDER BY name ASC;" + ).fetchall() conn.close() - return [row['name'] for row in results] + return [row["name"] for row in results] + def get_all_tickers(): """Retrieves the ID and symbol of every ticker in the database.""" @@ -381,6 +434,7 @@ def get_all_tickers(): conn.close() return results + def get_ticker_by_symbol(symbol): """ Retrieves a single ticker's ID and symbol from the database. 
@@ -388,11 +442,14 @@ def get_ticker_by_symbol(symbol): """ conn = get_db_connection() cursor = conn.cursor() - cursor.execute("SELECT id, symbol FROM tickers WHERE LOWER(symbol) = LOWER(?)", (symbol,)) + cursor.execute( + "SELECT id, symbol FROM tickers WHERE LOWER(symbol) = LOWER(?)", (symbol,) + ) result = cursor.fetchone() conn.close() return result + def get_top_daily_ticker_symbols(): """Gets a simple list of the Top 10 ticker symbols from the last 24 hours.""" conn = get_db_connection() @@ -405,7 +462,8 @@ def get_top_daily_ticker_symbols(): """ results = conn.execute(query, (one_day_ago_timestamp,)).fetchall() conn.close() - return [row['symbol'] for row in results] # Return a simple list of strings + return [row["symbol"] for row in results] # Return a simple list of strings + def get_top_weekly_ticker_symbols(): """Gets a simple list of the Top 10 ticker symbols from the last 7 days.""" @@ -419,7 +477,8 @@ def get_top_weekly_ticker_symbols(): """ results = conn.execute(query, (seven_days_ago_timestamp,)).fetchall() conn.close() - return [row['symbol'] for row in results] # Return a simple list of strings + return [row["symbol"] for row in results] # Return a simple list of strings + def get_top_daily_ticker_symbols_for_subreddit(subreddit_name): """Gets a list of the Top 10 daily ticker symbols for a specific subreddit.""" @@ -432,9 +491,16 @@ def get_top_daily_ticker_symbols_for_subreddit(subreddit_name): WHERE LOWER(s.name) = LOWER(?) AND m.mention_timestamp >= ? GROUP BY t.symbol ORDER BY COUNT(m.id) DESC LIMIT 10; """ - results = conn.execute(query, (subreddit_name, one_day_ago_timestamp,)).fetchall() + results = conn.execute( + query, + ( + subreddit_name, + one_day_ago_timestamp, + ), + ).fetchall() conn.close() - return [row['symbol'] for row in results] + return [row["symbol"] for row in results] + def get_top_weekly_ticker_symbols_for_subreddit(subreddit_name): """Gets a list of the Top 10 weekly ticker symbols for a specific subreddit.""" @@ -447,6 +513,12 @@ def get_top_weekly_ticker_symbols_for_subreddit(subreddit_name): WHERE LOWER(s.name) = LOWER(?) AND m.mention_timestamp >= ? GROUP BY t.symbol ORDER BY COUNT(m.id) DESC LIMIT 10; """ - results = conn.execute(query, (subreddit_name, seven_days_ago_timestamp,)).fetchall() + results = conn.execute( + query, + ( + subreddit_name, + seven_days_ago_timestamp, + ), + ).fetchall() conn.close() - return [row['symbol'] for row in results] \ No newline at end of file + return [row["symbol"] for row in results] diff --git a/rstat_tool/format_blacklist.py b/rstat_tool/format_blacklist.py index 32bc0dd..ea6338a 100644 --- a/rstat_tool/format_blacklist.py +++ b/rstat_tool/format_blacklist.py @@ -115,7 +115,7 @@ COMMON_WORDS_BLACKLIST = { def format_and_print_list(word_set, words_per_line=10): """ Sorts a set of words and prints it in a specific format. - + Args: word_set (set): The set of words to process. words_per_line (int): The number of words to print on each line. @@ -123,32 +123,33 @@ def format_and_print_list(word_set, words_per_line=10): # 1. Convert the set to a list to ensure order, and sort it alphabetically. # The set is also used to remove any duplicates from the initial list. sorted_words = sorted(list(word_set)) - + # 2. Start printing the output print("COMMON_WORDS_BLACKLIST = {") - + # 3. 
Iterate through the sorted list and print words, respecting the line limit for i in range(0, len(sorted_words), words_per_line): # Get a chunk of words for the current line - line_chunk = sorted_words[i:i + words_per_line] - + line_chunk = sorted_words[i : i + words_per_line] + # Format each word with double quotes formatted_words = [f'"{word}"' for word in line_chunk] - + # Join the words with a comma and a space line_content = ", ".join(formatted_words) - + # Add a trailing comma if it's not the last line is_last_line = (i + words_per_line) >= len(sorted_words) if not is_last_line: line_content += "," - + # Print the indented line print(f" {line_content}") # 4. Print the closing brace print("}") + # --- Main execution --- if __name__ == "__main__": - format_and_print_list(COMMON_WORDS_BLACKLIST) \ No newline at end of file + format_and_print_list(COMMON_WORDS_BLACKLIST) diff --git a/rstat_tool/logger_setup.py b/rstat_tool/logger_setup.py index 22a51f8..3b0a1ab 100644 --- a/rstat_tool/logger_setup.py +++ b/rstat_tool/logger_setup.py @@ -5,6 +5,7 @@ import sys logger = logging.getLogger("rstat_app") + def setup_logging(console_verbose=False, debug_mode=False): """ Configures the application's logger with a new DEBUG level. @@ -12,30 +13,32 @@ def setup_logging(console_verbose=False, debug_mode=False): # The logger itself must be set to the lowest possible level (DEBUG). log_level = logging.DEBUG if debug_mode else logging.INFO logger.setLevel(log_level) - + logger.propagate = False if logger.hasHandlers(): logger.handlers.clear() # File Handler (Always verbose at INFO level or higher) - file_handler = logging.FileHandler("rstat.log", mode='a') - file_handler.setLevel(logging.INFO) # We don't need debug spam in the file usually - file_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S') + file_handler = logging.FileHandler("rstat.log", mode="a") + file_handler.setLevel(logging.INFO) # We don't need debug spam in the file usually + file_formatter = logging.Formatter( + "%(asctime)s - %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S" + ) file_handler.setFormatter(file_formatter) logger.addHandler(file_handler) # Console Handler (Verbosity is controlled) console_handler = logging.StreamHandler(sys.stdout) - console_formatter = logging.Formatter('%(message)s') + console_formatter = logging.Formatter("%(message)s") console_handler.setFormatter(console_formatter) - + if debug_mode: console_handler.setLevel(logging.DEBUG) elif console_verbose: console_handler.setLevel(logging.INFO) else: console_handler.setLevel(logging.CRITICAL) - + logger.addHandler(console_handler) # YFINANCE LOGGER CAPTURE @@ -45,4 +48,4 @@ def setup_logging(console_verbose=False, debug_mode=False): yfinance_logger.handlers.clear() yfinance_logger.setLevel(logging.WARNING) yfinance_logger.addHandler(console_handler) - yfinance_logger.addHandler(file_handler) \ No newline at end of file + yfinance_logger.addHandler(file_handler) diff --git a/rstat_tool/main.py b/rstat_tool/main.py index b375190..238c579 100644 --- a/rstat_tool/main.py +++ b/rstat_tool/main.py @@ -16,27 +16,32 @@ from .ticker_extractor import extract_tickers from .sentiment_analyzer import get_sentiment_score from .logger_setup import setup_logging, logger as log + def load_subreddits(filepath): """Loads a list of subreddits from a JSON file.""" try: - with open(filepath, 'r') as f: + with open(filepath, "r") as f: return json.load(f).get("subreddits", []) except (FileNotFoundError, json.JSONDecodeError) 
as e: log.error(f"Error loading config file '{filepath}': {e}") return None + def get_reddit_instance(): """Initializes and returns a PRAW Reddit instance.""" - env_path = Path(__file__).parent.parent / '.env' + env_path = Path(__file__).parent.parent / ".env" load_dotenv(dotenv_path=env_path) - + client_id = os.getenv("REDDIT_CLIENT_ID") client_secret = os.getenv("REDDIT_CLIENT_SECRET") user_agent = os.getenv("REDDIT_USER_AGENT") if not all([client_id, client_secret, user_agent]): log.error("Error: Reddit API credentials not found in .env file.") return None - return praw.Reddit(client_id=client_id, client_secret=client_secret, user_agent=user_agent) + return praw.Reddit( + client_id=client_id, client_secret=client_secret, user_agent=user_agent + ) + def get_financial_data_via_fetcher(ticker_symbol): """ @@ -48,38 +53,45 @@ def get_financial_data_via_fetcher(ticker_symbol): # --- Call 1: Get Market Cap --- try: - mc_script_path = project_root / 'fetch_market_cap.py' + mc_script_path = project_root / "fetch_market_cap.py" command_mc = [sys.executable, str(mc_script_path), ticker_symbol] - result_mc = subprocess.run(command_mc, capture_output=True, text=True, check=True, timeout=30) + result_mc = subprocess.run( + command_mc, capture_output=True, text=True, check=True, timeout=30 + ) financials.update(json.loads(result_mc.stdout)) except Exception as e: log.warning(f"Market cap fetcher failed for {ticker_symbol}: {e}") # --- Call 2: Get Closing Price --- try: - cp_script_path = project_root / 'fetch_close_price.py' + cp_script_path = project_root / "fetch_close_price.py" command_cp = [sys.executable, str(cp_script_path), ticker_symbol] - result_cp = subprocess.run(command_cp, capture_output=True, text=True, check=True, timeout=30) + result_cp = subprocess.run( + command_cp, capture_output=True, text=True, check=True, timeout=30 + ) financials.update(json.loads(result_cp.stdout)) except Exception as e: log.warning(f"Closing price fetcher failed for {ticker_symbol}: {e}") - + return financials + # --- HELPER FUNCTION: Contains all the optimized logic for one post --- -def _process_submission(submission, subreddit_id, conn, comment_limit, fetch_financials): +def _process_submission( + submission, subreddit_id, conn, comment_limit, fetch_financials +): """ Processes a single Reddit submission with optimized logic. - Uses a single loop over comments. - Caches ticker IDs to reduce DB lookups. """ current_time = time.time() - + # 1. Initialize data collectors for this post tickers_in_title = set(extract_tickers(submission.title)) all_tickers_found_in_post = set(tickers_in_title) all_comment_sentiments = [] - ticker_id_cache = {} # In-memory cache for ticker IDs for this post + ticker_id_cache = {} # In-memory cache for ticker IDs for this post submission.comments.replace_more(limit=0) all_comments = submission.comments.list()[:comment_limit] @@ -88,8 +100,8 @@ def _process_submission(submission, subreddit_id, conn, comment_limit, fetch_fin # We gather all necessary information in one pass. 
for comment in all_comments: comment_sentiment = get_sentiment_score(comment.body) - all_comment_sentiments.append(comment_sentiment) # For the deep dive - + all_comment_sentiments.append(comment_sentiment) # For the deep dive + tickers_in_comment = set(extract_tickers(comment.body)) if not tickers_in_comment: continue @@ -101,147 +113,266 @@ def _process_submission(submission, subreddit_id, conn, comment_limit, fetch_fin # If the title has tickers, every comment is a mention for them for ticker_symbol in tickers_in_title: if ticker_symbol not in ticker_id_cache: - ticker_id_cache[ticker_symbol] = database.get_or_create_entity(conn, 'tickers', 'symbol', ticker_symbol) + ticker_id_cache[ticker_symbol] = database.get_or_create_entity( + conn, "tickers", "symbol", ticker_symbol + ) ticker_id = ticker_id_cache[ticker_symbol] - database.add_mention(conn, ticker_id, subreddit_id, submission.id, 'comment', int(comment.created_utc), comment_sentiment) + database.add_mention( + conn, + ticker_id, + subreddit_id, + submission.id, + "comment", + int(comment.created_utc), + comment_sentiment, + ) else: # If no title tickers, only direct mentions in comments count for ticker_symbol in tickers_in_comment: if ticker_symbol not in ticker_id_cache: - ticker_id_cache[ticker_symbol] = database.get_or_create_entity(conn, 'tickers', 'symbol', ticker_symbol) + ticker_id_cache[ticker_symbol] = database.get_or_create_entity( + conn, "tickers", "symbol", ticker_symbol + ) ticker_id = ticker_id_cache[ticker_symbol] - database.add_mention(conn, ticker_id, subreddit_id, submission.id, 'comment', int(comment.created_utc), comment_sentiment) + database.add_mention( + conn, + ticker_id, + subreddit_id, + submission.id, + "comment", + int(comment.created_utc), + comment_sentiment, + ) # 3. Process title mentions (if any) if tickers_in_title: - log.info(f" -> Title Mention(s): {', '.join(tickers_in_title)}. Attributing all comments.") + log.info( + f" -> Title Mention(s): {', '.join(tickers_in_title)}. Attributing all comments." + ) post_sentiment = get_sentiment_score(submission.title) for ticker_symbol in tickers_in_title: if ticker_symbol not in ticker_id_cache: - ticker_id_cache[ticker_symbol] = database.get_or_create_entity(conn, 'tickers', 'symbol', ticker_symbol) + ticker_id_cache[ticker_symbol] = database.get_or_create_entity( + conn, "tickers", "symbol", ticker_symbol + ) ticker_id = ticker_id_cache[ticker_symbol] - database.add_mention(conn, ticker_id, subreddit_id, submission.id, 'post', int(submission.created_utc), post_sentiment) + database.add_mention( + conn, + ticker_id, + subreddit_id, + submission.id, + "post", + int(submission.created_utc), + post_sentiment, + ) # 4. 
Fetch financial data if enabled if fetch_financials: for ticker_symbol in all_tickers_found_in_post: - ticker_id = ticker_id_cache[ticker_symbol] # Guaranteed to be in cache + ticker_id = ticker_id_cache[ticker_symbol] # Guaranteed to be in cache ticker_info = database.get_ticker_info(conn, ticker_id) - if not ticker_info['last_updated'] or (current_time - ticker_info['last_updated'] > database.MARKET_CAP_REFRESH_INTERVAL): + if not ticker_info["last_updated"] or ( + current_time - ticker_info["last_updated"] + > database.MARKET_CAP_REFRESH_INTERVAL + ): log.info(f" -> Fetching financial data for {ticker_symbol}...") financials = get_financial_data_via_fetcher(ticker_symbol) - database.update_ticker_financials(conn, ticker_id, financials.get('market_cap'), financials.get('closing_price')) - + database.update_ticker_financials( + conn, + ticker_id, + financials.get("market_cap"), + financials.get("closing_price"), + ) + # 5. Save deep dive analysis - avg_sentiment = sum(all_comment_sentiments) / len(all_comment_sentiments) if all_comment_sentiments else 0 + avg_sentiment = ( + sum(all_comment_sentiments) / len(all_comment_sentiments) + if all_comment_sentiments + else 0 + ) post_analysis_data = { - "post_id": submission.id, "title": submission.title, - "post_url": f"https://reddit.com{submission.permalink}", "subreddit_id": subreddit_id, - "post_timestamp": int(submission.created_utc), "comment_count": len(all_comments), - "avg_comment_sentiment": avg_sentiment + "post_id": submission.id, + "title": submission.title, + "post_url": f"https://reddit.com{submission.permalink}", + "subreddit_id": subreddit_id, + "post_timestamp": int(submission.created_utc), + "comment_count": len(all_comments), + "avg_comment_sentiment": avg_sentiment, } database.add_or_update_post_analysis(conn, post_analysis_data) -def scan_subreddits(reddit, subreddits_list, post_limit=100, comment_limit=100, days_to_scan=1, fetch_financials=True): + +def scan_subreddits( + reddit, + subreddits_list, + post_limit=100, + comment_limit=100, + days_to_scan=1, + fetch_financials=True, +): conn = database.get_db_connection() post_age_limit = days_to_scan * 86400 current_time = time.time() - - log.info(f"Scanning {len(subreddits_list)} subreddit(s) for NEW posts in the last {days_to_scan} day(s)...") + + log.info( + f"Scanning {len(subreddits_list)} subreddit(s) for NEW posts in the last {days_to_scan} day(s)..." + ) if not fetch_financials: log.warning("NOTE: Financial data fetching is disabled for this run.") for subreddit_name in subreddits_list: try: normalized_sub_name = subreddit_name.lower() - subreddit_id = database.get_or_create_entity(conn, 'subreddits', 'name', normalized_sub_name) + subreddit_id = database.get_or_create_entity( + conn, "subreddits", "name", normalized_sub_name + ) subreddit = reddit.subreddit(normalized_sub_name) log.info(f"Scanning r/{normalized_sub_name}...") - + for submission in subreddit.new(limit=post_limit): if (current_time - submission.created_utc) > post_age_limit: - log.info(f" -> Reached posts older than the {days_to_scan}-day limit.") + log.info( + f" -> Reached posts older than the {days_to_scan}-day limit." + ) break - + # Call the new helper function for each post - _process_submission(submission, subreddit_id, conn, comment_limit, fetch_financials) - + _process_submission( + submission, subreddit_id, conn, comment_limit, fetch_financials + ) + except Exception as e: - log.error(f"Could not scan r/{normalized_sub_name}. 
Error: {e}", exc_info=True) - + log.error( + f"Could not scan r/{normalized_sub_name}. Error: {e}", exc_info=True + ) + conn.close() log.critical("\n--- Scan Complete ---") + def main(): """Main function to run the Reddit stock analysis tool.""" - parser = argparse.ArgumentParser(description="Analyze stock ticker mentions on Reddit.", formatter_class=argparse.RawTextHelpFormatter) - - parser.add_argument("-f", "--config", default="subreddits.json", help="Path to the JSON file for scanning. (Default: subreddits.json)") - parser.add_argument("-s", "--subreddit", help="Scan a single subreddit, ignoring the config file.") - parser.add_argument("-d", "--days", type=int, default=1, help="Number of past days to scan for new posts. (Default: 1)") - parser.add_argument("-p", "--posts", type=int, default=200, help="Max posts to check per subreddit. (Default: 200)") - parser.add_argument("-c", "--comments", type=int, default=100, help="Number of comments to scan per post. (Default: 100)") - parser.add_argument("-n", "--no-financials", action="store_true", help="Disable fetching of financial data during the Reddit scan.") - parser.add_argument("--update-top-tickers", action="store_true", help="Update financial data only for tickers currently in the Top 10 daily/weekly dashboards.") - parser.add_argument( - "-u", "--update-financials-only", - nargs='?', - const="ALL_TICKERS", # A special value to signify "update all" - default=None, - metavar='TICKER', - help="Update financials. Provide a ticker symbol to update just one,\nor use the flag alone to update all tickers in the database." + parser = argparse.ArgumentParser( + description="Analyze stock ticker mentions on Reddit.", + formatter_class=argparse.RawTextHelpFormatter, ) - parser.add_argument("--debug", action="store_true", help="Enable detailed debug logging to the console.") - parser.add_argument("--stdout", action="store_true", help="Print all log messages to the console.") - + + parser.add_argument( + "-f", + "--config", + default="subreddits.json", + help="Path to the JSON file for scanning. (Default: subreddits.json)", + ) + parser.add_argument( + "-s", "--subreddit", help="Scan a single subreddit, ignoring the config file." + ) + parser.add_argument( + "-d", + "--days", + type=int, + default=1, + help="Number of past days to scan for new posts. (Default: 1)", + ) + parser.add_argument( + "-p", + "--posts", + type=int, + default=200, + help="Max posts to check per subreddit. (Default: 200)", + ) + parser.add_argument( + "-c", + "--comments", + type=int, + default=100, + help="Number of comments to scan per post. (Default: 100)", + ) + parser.add_argument( + "-n", + "--no-financials", + action="store_true", + help="Disable fetching of financial data during the Reddit scan.", + ) + parser.add_argument( + "--update-top-tickers", + action="store_true", + help="Update financial data only for tickers currently in the Top 10 daily/weekly dashboards.", + ) + parser.add_argument( + "-u", + "--update-financials-only", + nargs="?", + const="ALL_TICKERS", # A special value to signify "update all" + default=None, + metavar="TICKER", + help="Update financials. Provide a ticker symbol to update just one,\nor use the flag alone to update all tickers in the database.", + ) + parser.add_argument( + "--debug", + action="store_true", + help="Enable detailed debug logging to the console.", + ) + parser.add_argument( + "--stdout", action="store_true", help="Print all log messages to the console." 
+ ) + args = parser.parse_args() setup_logging(console_verbose=args.stdout, debug_mode=args.debug) database.initialize_db() - + if args.update_top_tickers: log.critical("--- Starting Financial Data Update for Top Tickers ---") - + # 1. Start with an empty set to hold all unique tickers tickers_to_update = set() - + # 2. Get the overall top tickers log.info("-> Checking overall top daily and weekly tickers...") top_daily_overall = database.get_top_daily_ticker_symbols() top_weekly_overall = database.get_top_weekly_ticker_symbols() tickers_to_update.update(top_daily_overall) tickers_to_update.update(top_weekly_overall) - + # 3. Get all subreddits and loop through them all_subreddits = database.get_all_scanned_subreddits() - log.info(f"-> Checking top tickers for {len(all_subreddits)} individual subreddit(s)...") + log.info( + f"-> Checking top tickers for {len(all_subreddits)} individual subreddit(s)..." + ) for sub_name in all_subreddits: log.debug(f" -> Checking r/{sub_name}...") - top_daily_sub = database.get_top_daily_ticker_symbols_for_subreddit(sub_name) - top_weekly_sub = database.get_top_weekly_ticker_symbols_for_subreddit(sub_name) + top_daily_sub = database.get_top_daily_ticker_symbols_for_subreddit( + sub_name + ) + top_weekly_sub = database.get_top_weekly_ticker_symbols_for_subreddit( + sub_name + ) tickers_to_update.update(top_daily_sub) tickers_to_update.update(top_weekly_sub) unique_top_tickers = sorted(list(tickers_to_update)) - + if not unique_top_tickers: log.info("No top tickers found in the last week. Nothing to update.") else: - log.info(f"Found {len(unique_top_tickers)} unique top tickers to update: {', '.join(unique_top_tickers)}") + log.info( + f"Found {len(unique_top_tickers)} unique top tickers to update: {', '.join(unique_top_tickers)}" + ) conn = database.get_db_connection() for ticker_symbol in unique_top_tickers: # 4. 
Find the ticker's ID to perform the update ticker_info = database.get_ticker_by_symbol(ticker_symbol) if ticker_info: log.info(f" -> Updating financials for {ticker_info['symbol']}...") - financials = get_financial_data_via_fetcher(ticker_info['symbol']) + financials = get_financial_data_via_fetcher(ticker_info["symbol"]) database.update_ticker_financials( - conn, ticker_info['id'], - financials.get('market_cap'), - financials.get('closing_price') + conn, + ticker_info["id"], + financials.get("market_cap"), + financials.get("closing_price"), ) conn.close() - + log.critical("--- Top Ticker Financial Data Update Complete ---") elif args.update_financials_only: @@ -253,31 +384,37 @@ def main(): log.info(f"Found {len(all_tickers)} tickers in the database to update.") conn = database.get_db_connection() for ticker in all_tickers: - symbol = ticker['symbol'] + symbol = ticker["symbol"] log.info(f" -> Updating financials for {symbol}...") financials = get_financial_data_via_fetcher(symbol) database.update_ticker_financials( - conn, ticker['id'], - financials.get('market_cap'), - financials.get('closing_price') + conn, + ticker["id"], + financials.get("market_cap"), + financials.get("closing_price"), ) conn.close() else: ticker_symbol_to_update = update_mode - log.critical(f"--- Starting Financial Data Update for single ticker: {ticker_symbol_to_update} ---") + log.critical( + f"--- Starting Financial Data Update for single ticker: {ticker_symbol_to_update} ---" + ) ticker_info = database.get_ticker_by_symbol(ticker_symbol_to_update) if ticker_info: conn = database.get_db_connection() log.info(f" -> Updating financials for {ticker_info['symbol']}...") - financials = get_financial_data_via_fetcher(ticker_info['symbol']) + financials = get_financial_data_via_fetcher(ticker_info["symbol"]) database.update_ticker_financials( - conn, ticker_info['id'], - financials.get('market_cap'), - financials.get('closing_price') + conn, + ticker_info["id"], + financials.get("market_cap"), + financials.get("closing_price"), ) conn.close() else: - log.error(f"Ticker '{ticker_symbol_to_update}' not found in the database.") + log.error( + f"Ticker '{ticker_symbol_to_update}' not found in the database." + ) log.critical("--- Financial Data Update Complete ---") else: @@ -288,14 +425,15 @@ def main(): log.info(f"Targeted Scan Mode: Focusing on r/{args.subreddit}") else: log.info(f"Config Scan Mode: Loading subreddits from {args.config}") - subreddits_to_scan = load_subreddits(args.config) + subreddits_to_scan = load_subreddits(args.config) if not subreddits_to_scan: log.error("Error: No subreddits to scan.") return reddit = get_reddit_instance() - if not reddit: return + if not reddit: + return scan_subreddits( reddit, @@ -303,8 +441,9 @@ def main(): post_limit=args.posts, comment_limit=args.comments, days_to_scan=args.days, - fetch_financials=(not args.no_financials) + fetch_financials=(not args.no_financials), ) + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/rstat_tool/sentiment_analyzer.py b/rstat_tool/sentiment_analyzer.py index 32b08e8..c21ae24 100644 --- a/rstat_tool/sentiment_analyzer.py +++ b/rstat_tool/sentiment_analyzer.py @@ -9,11 +9,11 @@ _analyzer = SentimentIntensityAnalyzer() def get_sentiment_score(text): """ Analyzes a piece of text and returns its sentiment score. - + The 'compound' score is a single metric that summarizes the sentiment. It ranges from -1 (most negative) to +1 (most positive). 
""" # The polarity_scores() method returns a dictionary with 'neg', 'neu', 'pos', and 'compound' scores. # We are most interested in the 'compound' score. scores = _analyzer.polarity_scores(text) - return scores['compound'] \ No newline at end of file + return scores["compound"] diff --git a/rstat_tool/setup_nltk.py b/rstat_tool/setup_nltk.py index bfd5209..98803af 100644 --- a/rstat_tool/setup_nltk.py +++ b/rstat_tool/setup_nltk.py @@ -3,9 +3,9 @@ import nltk # This will download the 'vader_lexicon' dataset # It only needs to be run once try: - nltk.data.find('sentiment/vader_lexicon.zip') + nltk.data.find("sentiment/vader_lexicon.zip") print("VADER lexicon is already downloaded.") except LookupError: print("Downloading VADER lexicon...") - nltk.download('vader_lexicon') - print("Download complete.") \ No newline at end of file + nltk.download("vader_lexicon") + print("Download complete.") diff --git a/rstat_tool/ticker_extractor.py b/rstat_tool/ticker_extractor.py index 0fa7747..6df2137 100644 --- a/rstat_tool/ticker_extractor.py +++ b/rstat_tool/ticker_extractor.py @@ -135,4 +135,4 @@ def extract_tickers(text): if cleaned_ticker not in COMMON_WORDS_BLACKLIST: tickers.append(cleaned_ticker) - return tickers \ No newline at end of file + return tickers diff --git a/setup.py b/setup.py index f764442..7b2df8c 100644 --- a/setup.py +++ b/setup.py @@ -2,24 +2,24 @@ from setuptools import setup, find_packages -with open('requirements.txt') as f: +with open("requirements.txt") as f: requirements = f.read().splitlines() setup( - name='reddit-stock-analyzer', - version='0.0.1', - author='Pål-Kristian Hamre', - author_email='its@pkhamre.com', - description='A command-line tool to analyze stock ticker mentions on Reddit.', + name="reddit-stock-analyzer", + version="0.0.1", + author="Pål-Kristian Hamre", + author_email="its@pkhamre.com", + description="A command-line tool to analyze stock ticker mentions on Reddit.", # This now correctly finds your 'rstat_tool' package - packages=find_packages(), + packages=find_packages(), install_requires=requirements, entry_points={ - 'console_scripts': [ + "console_scripts": [ # The path is now 'package_name.module_name:function_name' - 'rstat=rstat_tool.main:main', - 'rstat-dashboard=rstat_tool.dashboard:start_dashboard', - 'rstat-cleanup=rstat_tool.cleanup:run_cleanup', + "rstat=rstat_tool.main:main", + "rstat-dashboard=rstat_tool.dashboard:start_dashboard", + "rstat-cleanup=rstat_tool.cleanup:run_cleanup", ], }, -) \ No newline at end of file +) diff --git a/templates/dashboard_base.html b/templates/dashboard_base.html index eb4a4b9..304037f 100644 --- a/templates/dashboard_base.html +++ b/templates/dashboard_base.html @@ -1,5 +1,6 @@ + @@ -8,21 +9,64 @@ + {% if not is_image_mode %} {% endif %} @@ -149,4 +349,5 @@ {% block content %}{% endblock %} + \ No newline at end of file diff --git a/templates/dashboard_view.html b/templates/dashboard_view.html index 3b3a3c3..d85502b 100644 --- a/templates/dashboard_view.html +++ b/templates/dashboard_view.html @@ -12,11 +12,12 @@
{{ date_string }}
- + {% if not is_image_mode %} - + @@ -42,28 +43,28 @@ {% if is_image_mode %} - {{ ticker.symbol }} + {{ ticker.symbol }} {% else %} - {{ ticker.symbol }} + {{ ticker.symbol }} {% endif %} {{ ticker.total_mentions }} {% if ticker.bullish_mentions > ticker.bearish_mentions %} - Bullish + Bullish {% elif ticker.bearish_mentions > ticker.bullish_mentions %} - Bearish + Bearish {% else %} - Neutral + Neutral {% endif %} {{ ticker.market_cap | format_mc }} {% if ticker.closing_price %} - ${{ "%.2f"|format(ticker.closing_price) }} + ${{ "%.2f"|format(ticker.closing_price) }} {% else %} - N/A + N/A {% endif %} diff --git a/templates/deep_dive.html b/templates/deep_dive.html index 5ccde64..7f45a37 100644 --- a/templates/deep_dive.html +++ b/templates/deep_dive.html @@ -3,27 +3,27 @@ {% block title %}Deep Dive: {{ symbol }}{% endblock %} {% block content %} -

[deep_dive.html hunk — markup stripped in rendering: the removed and re-added lines carry the same visible text ("Deep Dive Analysis for: {{ symbol }}", "Showing posts that mention {{ symbol }}, sorted by most recent.", a {% for post in posts %} loop showing {{ post.title }}, and the {% else %} fallback "No analyzed posts found for this ticker. Run the 'rstat' scraper to gather data."); as far as the rendered text shows, the two sides differ only in indentation.]
+{% endfor %} {% endblock %} \ No newline at end of file diff --git a/yfinance_test.py b/yfinance_test.py index cce707a..7acebf1 100644 --- a/yfinance_test.py +++ b/yfinance_test.py @@ -6,8 +6,7 @@ import logging # Set up a simple logger to see detailed error tracebacks logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(levelname)s - %(message)s' + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" ) # A list of tickers to test. One very common one, and two from your logs. @@ -20,31 +19,41 @@ for ticker_symbol in TICKERS_TO_TEST: # --- Test 1: The Ticker().info method --- try: - logging.info(f"Attempting to create Ticker object and get .info for {ticker_symbol}...") + logging.info( + f"Attempting to create Ticker object and get .info for {ticker_symbol}..." + ) ticker_obj = yf.Ticker(ticker_symbol) - market_cap = ticker_obj.info.get('marketCap') + market_cap = ticker_obj.info.get("marketCap") if market_cap is not None: logging.info(f"SUCCESS: Got market cap for {ticker_symbol}: {market_cap}") else: - logging.warning(f"PARTIAL SUCCESS: .info call for {ticker_symbol} worked, but no market cap was found.") + logging.warning( + f"PARTIAL SUCCESS: .info call for {ticker_symbol} worked, but no market cap was found." + ) except Exception: - logging.error(f"FAILURE: An error occurred during the Ticker().info call for {ticker_symbol}.", exc_info=True) - + logging.error( + f"FAILURE: An error occurred during the Ticker().info call for {ticker_symbol}.", + exc_info=True, + ) # --- Test 2: The yf.download() method --- try: logging.info(f"Attempting yf.download() for {ticker_symbol}...") data = yf.download( - ticker_symbol, - period="2d", - progress=False, - auto_adjust=False + ticker_symbol, period="2d", progress=False, auto_adjust=False ) if not data.empty: - logging.info(f"SUCCESS: yf.download() for {ticker_symbol} returned {len(data)} rows of data.") + logging.info( + f"SUCCESS: yf.download() for {ticker_symbol} returned {len(data)} rows of data." + ) else: - logging.warning(f"PARTIAL SUCCESS: yf.download() for {ticker_symbol} worked, but returned no data (likely delisted).") + logging.warning( + f"PARTIAL SUCCESS: yf.download() for {ticker_symbol} worked, but returned no data (likely delisted)." + ) except Exception: - logging.error(f"FAILURE: An error occurred during the yf.download() call for {ticker_symbol}.", exc_info=True) + logging.error( + f"FAILURE: An error occurred during the yf.download() call for {ticker_symbol}.", + exc_info=True, + ) -print("\n--- YFINANCE Diagnostic Test Complete ---") \ No newline at end of file +print("\n--- YFINANCE Diagnostic Test Complete ---")