#!/usr/bin/env python3
"""FastAPI application entry point with CLI argument parsing."""
# Standard library
import argparse
from contextlib import asynccontextmanager

# Third-party
from fastapi import FastAPI, HTTPException, Query, Request
from fastapi.responses import JSONResponse

# Local application
from config import Config
from models import SubredditQuery, PostQuery, CustomQuery
from scraper.selenium_scrapers import get_scraper
# Module-level scraper handle shared by every endpoint.
# None until the FastAPI lifespan context assigns it at startup.
scraper = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: create the shared scraper before serving requests.

    The scraper is built exactly once at startup and reused by every endpoint
    via the module-level ``scraper`` global. Teardown logic, if ever needed,
    belongs after the ``yield``.
    """
    global scraper
    scraper = get_scraper()
    yield
# FastAPI application instance; `lifespan` wires up the shared scraper at startup.
app = FastAPI(
    title="Reddit Super Duper Scraper",
    description="A powerful tool to scrape public Reddit data without authentication. Accessible via local network only.",
    version="1.0.0",
    lifespan=lifespan
)
@app.exception_handler(HTTPException)
async def custom_http_exception_handler(request: Request, exc: HTTPException):
    """Translate HTTPExceptions into themed JSON error payloads.

    Known status codes get a canned message; any other status falls back to a
    message built from the exception's detail text.
    """
    canned = {
        400: "The boat went on fire (Bad Request)",
        404: "The boat went on fire (Not Found)",
        429: "The boat went on fire (Too Many Requests - Rate Limited)",
        500: "The boat went on fire (Internal Server Error)",
    }
    friendly = canned.get(exc.status_code)
    if friendly is None:
        friendly = f"The boat went on fire ({exc.detail})"
    return JSONResponse(
        status_code=exc.status_code,
        content={"Error": friendly},
    )
@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception):
    """Catch-all handler: any uncaught exception becomes a themed 500 response."""
    payload = {"Error": "The boat went on fire (Unexpected Error)"}
    return JSONResponse(status_code=500, content=payload)
@app.get("/health")
async def health_check():
    """Liveness probe: always reports a healthy status."""
    return {
        "status": "healthy",
        "message": "The ship is sailing smoothly",
    }
@app.get("/scrape/subreddit/{subreddit}")
async def scrape_subreddit(
    subreddit: str,
    limit: int = Query(default=10, ge=1, le=100),
    time_range: str = Query(default="week"),
    depth: int = Query(default=1, ge=1, le=10),
    include_comments: bool = Query(default=True)
):
    """
    Scrape top posts from a subreddit with nested comments.

    - **subreddit**: Name of the subreddit (without 'r/')
    - **limit**: Number of top posts to retrieve (1-100)
    - **time_range**: Time filter ('hour', 'day', 'week', 'month', 'year', 'all')
    - **depth**: Maximum comment nesting depth (1-10)
    - **include_comments**: Whether to scrape comments (True/False, default: True)

    Returns post data with or without comment threads based on parameter.
    Setting `include_comments=false` provides faster response times as it skips
    the additional DOM traversal required for comment extraction.

    Raises 400 for an unrecognized time_range and 500 when the scraper
    reports an error.
    """
    # Enforce the documented time_range values instead of silently forwarding
    # arbitrary strings to the scraper (previously unvalidated).
    valid_time_ranges = {"hour", "day", "week", "month", "year", "all"}
    if time_range not in valid_time_ranges:
        raise HTTPException(
            status_code=400,
            detail=f"time_range must be one of {sorted(valid_time_ranges)}"
        )

    result = scraper.scrape_subreddit_top(
        subreddit=subreddit,
        limit=limit,
        time_range=time_range,
        depth=depth,
        include_comments=include_comments
    )

    # The scraper signals failures by embedding an "Error" key in its result.
    if "Error" in result:
        raise HTTPException(status_code=500, detail=str(result["Error"]))

    return result
@app.get("/scrape/post/{post_id}")
async def scrape_post(post_id: str, depth: int = Query(default=3, ge=1, le=10)):
    """
    Scrape all comments from a specific Reddit post with nested replies.

    - **post_id**: Reddit post ID (without 't3_')
    - **depth**: Maximum comment nesting depth (1-10)
    """
    # Delegate to the shared scraper; it reports failures via an "Error" key.
    result = scraper.scrape_post_comments(post_id=post_id, depth=depth)
    if "Error" in result:
        raise HTTPException(status_code=500, detail=str(result["Error"]))
    return result
@app.post("/scrape/custom")
async def scrape_custom(query: CustomQuery):
    """
    Flexible endpoint for custom scraping queries.

    - **type**: Type of scrape ('subreddit' or 'post')
    - **target**: Subreddit name or post ID
    - **limit**: Number of posts (for subreddit type)
    - **time_range**: Time filter (for subreddit type)
    - **depth**: Maximum comment nesting depth
    - **include_comments**: Whether to scrape comments (set False for faster results)

    Raises 400 for an unknown query type and 500 when the scraper reports
    an error.
    """
    if query.type == "subreddit":
        result = scraper.scrape_subreddit_top(
            subreddit=query.target,
            limit=query.limit,
            time_range=query.time_range,
            depth=query.depth,
            include_comments=query.include_comments
        )
    elif query.type == "post":
        if not query.include_comments:
            # Fetch post metadata only (depth=0 skips comment traversal) for a
            # faster response, then blank out the comment payload explicitly.
            result = scraper.scrape_post_comments(query.target, depth=0)
            if "Error" not in result and "data" in result:
                # Return empty comments list since we're skipping them
                result["data"] = []
        else:
            result = scraper.scrape_post_comments(
                post_id=query.target,
                depth=query.depth
            )
    else:
        # BUG FIX: an unknown type previously fell through with `result`
        # unbound, raising UnboundLocalError and surfacing as an opaque 500.
        # Fail fast with a clear 400 instead.
        raise HTTPException(
            status_code=400,
            detail=f"Unknown query type '{query.type}'; expected 'subreddit' or 'post'"
        )

    # The scraper signals failures by embedding an "Error" key in its result.
    if "Error" in result:
        raise HTTPException(status_code=500, detail=str(result["Error"]))

    return result
@app.get("/")
async def root():
    """Root endpoint describing the API and its available routes."""
    endpoints = {
        "/scrape/subreddit/{subreddit}": "GET - Scrape top posts from a subreddit",
        "/scrape/post/{post_id}": "GET - Scrape comments from a specific post",
        "/scrape/custom": "POST - Flexible custom scraping query",
        "/health": "GET - Health check endpoint",
    }
    return {
        "name": "Reddit Super Duper Scraper",
        "version": "1.0.0",
        "description": "Scrape public Reddit data without authentication",
        "endpoints": endpoints,
        "docs": "/docs",
        "redoc": "/redoc",
    }
def parse_args():
    """Build and evaluate the CLI argument parser for server configuration.

    Returns an argparse.Namespace with ``host`` and ``port``, defaulting to
    the values declared on Config.
    """
    parser = argparse.ArgumentParser(
        description="Reddit Super Duper Scraper - Scrape public Reddit data via local API"
    )
    parser.add_argument(
        "--port",
        type=int,
        default=Config.DEFAULT_PORT,
        help=f"Port to run the server on (default: {Config.DEFAULT_PORT})",
    )
    parser.add_argument(
        "--host",
        type=str,
        default=Config.DEFAULT_HOST,
        help=f"Host to bind to (default: {Config.DEFAULT_HOST})",
    )
    return parser.parse_args()
if __name__ == "__main__":
    args = parse_args()

    # Imported lazily so merely importing this module never pulls in uvicorn.
    import uvicorn

    # Startup banner; the docs link uses localhost since the API is local-only.
    print(f"🚀 Starting Reddit Super Duper Scraper on http://{args.host}:{args.port}")
    print(f"📖 API documentation available at http://localhost:{args.port}/docs")
    print("💡 Accessible via local network only - no authentication required")

    uvicorn.run(
        "main:app",
        host=args.host,
        port=args.port,
        reload=False,
    )