reddit_scraper/main.py

209 lines
6.7 KiB
Python
Executable File

"""FastAPI application entry point with CLI argument parsing."""
import argparse
import logging
from contextlib import asynccontextmanager

from fastapi import FastAPI, HTTPException, Query, Request
from fastapi.responses import JSONResponse

from config import Config
from models import SubredditQuery, PostQuery, CustomQuery
from scraper.selenium_scrapers import get_scraper
# Shared scraper handle, populated exactly once when the app starts.
scraper = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: build the scraper at startup, hold it for the
    process lifetime.

    Teardown logic (e.g. closing the Selenium driver) would go after the
    ``yield`` if it becomes necessary.
    """
    global scraper
    scraper = get_scraper()
    yield
# ASGI application object; served by uvicorn from the __main__ block.
app = FastAPI(
    lifespan=lifespan,
    title="Reddit Super Duper Scraper",
    version="1.0.0",
    description=(
        "A powerful tool to scrape public Reddit data without "
        "authentication. Accessible via local network only."
    ),
)
@app.exception_handler(HTTPException)
async def custom_http_exception_handler(request: Request, exc: HTTPException):
    """Map HTTPExceptions to themed JSON error responses.

    Fix: for the known status codes (400/404/429/500) the original
    ``exc.detail`` — e.g. the scraper error string raised by the endpoint
    handlers — was silently discarded, making failures impossible to
    diagnose from the response. The friendly message is kept under the
    legacy ``"Error"`` key and the real cause is now surfaced under a
    new ``"Detail"`` key (backward-compatible addition).
    """
    error_messages = {
        400: "The boat went on fire (Bad Request)",
        404: "The boat went on fire (Not Found)",
        429: "The boat went on fire (Too Many Requests - Rate Limited)",
        500: "The boat went on fire (Internal Server Error)"
    }
    message = error_messages.get(exc.status_code, f"The boat went on fire ({exc.detail})")
    return JSONResponse(
        status_code=exc.status_code,
        # "Error" preserves the legacy payload shape; "Detail" carries
        # the underlying reason for debugging.
        content={"Error": message, "Detail": exc.detail}
    )
@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception):
    """Catch-all handler: log the traceback, then return a friendly 500.

    Fix: the exception was previously swallowed with no record at all,
    leaving nothing in the server output to debug with. We now log the
    full traceback (``logger.exception`` attaches it automatically)
    before sending the same user-facing payload as before.
    """
    logging.getLogger(__name__).exception(
        "Unhandled error while serving %s", request.url
    )
    return JSONResponse(
        status_code=500,
        content={"Error": "The boat went on fire (Unexpected Error)"}
    )
@app.get("/health")
async def health_check():
    """Liveness probe: confirms the API process is up and responsive."""
    payload = {"status": "healthy", "message": "The ship is sailing smoothly"}
    return payload
# Time filters accepted by Reddit's "top" listing; mirrors the set the
# endpoint docstring already promised but never enforced.
_VALID_TIME_RANGES = {"hour", "day", "week", "month", "year", "all"}


@app.get("/scrape/subreddit/{subreddit}")
async def scrape_subreddit(
    subreddit: str,
    limit: int = Query(default=10, ge=1, le=100),
    time_range: str = Query(default="week"),
    depth: int = Query(default=1, ge=1, le=10),
    include_comments: bool = Query(default=True)
):
    """
    Scrape top posts from a subreddit with nested comments.
    - **subreddit**: Name of the subreddit (without 'r/')
    - **limit**: Number of top posts to retrieve (1-100)
    - **time_range**: Time filter ('hour', 'day', 'week', 'month', 'year', 'all')
    - **depth**: Maximum comment nesting depth (1-10)
    - **include_comments**: Whether to scrape comments (True/False, default: True)
    Returns post data with or without comment threads based on parameter.
    Setting `include_comments=false` provides faster response times as it skips
    the additional DOM traversal required for comment extraction.
    Raises HTTP 400 for an invalid ``time_range`` and HTTP 500 when the
    scraper reports an error.
    """
    # Fix: the docstring listed valid time filters but nothing enforced
    # them — a bad value went straight into the Selenium scraper. Reject
    # it up front with a clear 400 instead.
    if time_range not in _VALID_TIME_RANGES:
        raise HTTPException(
            status_code=400,
            detail=(
                f"Invalid time_range '{time_range}'; "
                f"expected one of {sorted(_VALID_TIME_RANGES)}"
            )
        )
    result = scraper.scrape_subreddit_top(
        subreddit=subreddit,
        limit=limit,
        time_range=time_range,
        depth=depth,
        include_comments=include_comments
    )
    # Scrapers signal failure via an "Error" key rather than raising.
    if "Error" in result:
        raise HTTPException(status_code=500, detail=str(result["Error"]))
    return result
@app.get("/scrape/post/{post_id}")
async def scrape_post(post_id: str, depth: int = Query(default=3, ge=1, le=10)):
    """
    Scrape all comments from a specific Reddit post with nested replies.
    - **post_id**: Reddit post ID (without 't3_')
    - **depth**: Maximum comment nesting depth (1-10)
    """
    scraped = scraper.scrape_post_comments(post_id=post_id, depth=depth)
    # The scraper reports failure via an "Error" key instead of raising.
    if "Error" not in scraped:
        return scraped
    raise HTTPException(status_code=500, detail=str(scraped["Error"]))
@app.post("/scrape/custom")
async def scrape_custom(query: CustomQuery):
    """
    Flexible endpoint for custom scraping queries.
    - **type**: Type of scrape ('subreddit' or 'post')
    - **target**: Subreddit name or post ID
    - **limit**: Number of posts (for subreddit type)
    - **time_range**: Time filter (for subreddit type)
    - **depth**: Maximum comment nesting depth
    - **include_comments**: Whether to scrape comments (set False for faster results)
    Raises HTTP 400 for an unknown query type and HTTP 500 when the
    scraper reports an error.
    """
    if query.type == "subreddit":
        result = scraper.scrape_subreddit_top(
            subreddit=query.target,
            limit=query.limit,
            time_range=query.time_range,
            depth=query.depth,
            include_comments=query.include_comments
        )
    elif query.type == "post":
        if not query.include_comments:
            # Fetch post metadata only (depth=0) for a faster response,
            # then clear "data" so the payload reflects "no comments".
            result = scraper.scrape_post_comments(query.target, depth=0)
            if "Error" not in result and "data" in result:
                result["data"] = []
        else:
            result = scraper.scrape_post_comments(
                post_id=query.target,
                depth=query.depth
            )
    else:
        # Fix: for any other query.type, 'result' was left unbound and the
        # "Error" membership test below crashed with UnboundLocalError,
        # surfacing as an opaque 500. Report the bad input explicitly.
        raise HTTPException(
            status_code=400,
            detail=f"Unknown query type '{query.type}'; expected 'subreddit' or 'post'"
        )
    # Scrapers signal failure via an "Error" key rather than raising.
    if "Error" in result:
        raise HTTPException(status_code=500, detail=str(result["Error"]))
    return result
@app.get("/")
async def root():
    """Root endpoint: self-describing API index for human visitors."""
    endpoint_index = {
        "/scrape/subreddit/{subreddit}": "GET - Scrape top posts from a subreddit",
        "/scrape/post/{post_id}": "GET - Scrape comments from a specific post",
        "/scrape/custom": "POST - Flexible custom scraping query",
        "/health": "GET - Health check endpoint"
    }
    return {
        "name": "Reddit Super Duper Scraper",
        "version": "1.0.0",
        "description": "Scrape public Reddit data without authentication",
        "endpoints": endpoint_index,
        "docs": "/docs",
        "redoc": "/redoc"
    }
def parse_args():
    """Parse CLI options for the server (``--host`` and ``--port``).

    Defaults come from :class:`Config`, so the help text always reflects
    the configured values.
    """
    arg_parser = argparse.ArgumentParser(
        description="Reddit Super Duper Scraper - Scrape public Reddit data via local API"
    )
    arg_parser.add_argument(
        "--port",
        type=int,
        default=Config.DEFAULT_PORT,
        help=f"Port to run the server on (default: {Config.DEFAULT_PORT})",
    )
    arg_parser.add_argument(
        "--host",
        type=str,
        default=Config.DEFAULT_HOST,
        help=f"Host to bind to (default: {Config.DEFAULT_HOST})",
    )
    return arg_parser.parse_args()
if __name__ == "__main__":
    # uvicorn is only needed when running as a script, so import lazily.
    import uvicorn

    cli = parse_args()
    print(f"🚀 Starting Reddit Super Duper Scraper on http://{cli.host}:{cli.port}")
    print(f"📖 API documentation available at http://localhost:{cli.port}/docs")
    print("💡 Accessible via local network only - no authentication required")
    uvicorn.run("main:app", host=cli.host, port=cli.port, reload=False)