#!/usr/bin/env python3
"""
dj_edit_worker.py — Smart Intro/Outro Generator Worker

Usage:
    python3 dj_edit_worker.py <job_json_base64>

Job JSON contains: job_id, s3_key, filename, user_id, edit_type, intro_bars, outro_bars, fade_outro, bpm_hint, callback_url

Processing pipeline:
    1. Download original from S3
    2. Detect BPM via librosa
    3. Calculate beat grid and bar positions
    4. Generate intro/outro by looping beat-aligned segments
    5. Export to MP3
    6. Upload to S3
    7. Callback to rpool.id with result
"""

import sys
import os
import json
import base64
import tempfile
import shutil
import subprocess
import time
import traceback
import urllib.request

import numpy as np
import librosa
import soundfile as sf
import boto3
from botocore.config import Config as BotoConfig
from scipy.signal import butter, sosfilt

# ── S3 Configuration (same Wasabi bucket as RPOOL) ──
# SECURITY NOTE(review): real-looking access keys are embedded below as
# env-var fallbacks, so they ship with the source. These credentials
# should be rotated and the fallbacks removed — prefer failing fast when
# S3_ACCESS_KEY / S3_SECRET_KEY are not set in the environment.
S3_BUCKET = 'rpool-storage'
S3_REGION = 'ap-northeast-1'
S3_ENDPOINT = f'https://s3.{S3_REGION}.wasabisys.com'
S3_ACCESS_KEY = os.environ.get('S3_ACCESS_KEY', 'D2CL9JKNNI0TLTXB9CE9')
S3_SECRET_KEY = os.environ.get('S3_SECRET_KEY', 'kBm3C67E6tnme7POkkGLXXwlDB1gT3rYecCII9Vj')

def get_s3_client():
    """Build a boto3 S3 client for the Wasabi endpoint (SigV4 signing)."""
    signing_config = BotoConfig(signature_version='s3v4')
    return boto3.client(
        's3',
        endpoint_url=S3_ENDPOINT,
        region_name=S3_REGION,
        aws_access_key_id=S3_ACCESS_KEY,
        aws_secret_access_key=S3_SECRET_KEY,
        config=signing_config,
    )


def log(msg):
    """Emit a tagged progress line, flushed immediately for the job runner."""
    prefix = "[dj_edit]"
    print(f"{prefix} {msg}", flush=True)


def send_callback(callback_url, data):
    """POST *data* as JSON to rpool.id; failures are logged, never raised."""
    if not callback_url:
        return
    try:
        body = json.dumps(data).encode('utf-8')
        request = urllib.request.Request(
            callback_url,
            data=body,
            headers={'Content-Type': 'application/json'},
        )
        urllib.request.urlopen(request, timeout=10)
    except Exception as exc:
        log(f"Callback failed: {exc}")


def send_progress(callback_url, api_key, job_id, progress, message):
    """Shorthand for a 'processing' status callback with a progress percentage."""
    payload = {
        'api_key': api_key,
        'job_id': job_id,
        'status': 'processing',
        'progress': progress,
        'progress_message': message,
    }
    send_callback(callback_url, payload)


def download_from_s3(s3_key, dest_path):
    """Fetch an object from the Wasabi bucket; raise if the result looks truncated."""
    log(f"Downloading s3://{S3_BUCKET}/{s3_key}")
    client = get_s3_client()
    client.download_file(S3_BUCKET, s3_key, dest_path)
    looks_bad = not os.path.exists(dest_path) or os.path.getsize(dest_path) < 1000
    if looks_bad:
        raise RuntimeError("Downloaded file is missing or too small")


def upload_to_s3(local_path, s3_key):
    """Push a local file to the Wasabi bucket, tagged as MP3 content."""
    log(f"Uploading to s3://{S3_BUCKET}/{s3_key}")
    client = get_s3_client()
    extra = {'ContentType': 'audio/mpeg'}
    client.upload_file(local_path, S3_BUCKET, s3_key, ExtraArgs=extra)


def detect_bpm(audio_path, bpm_hint=None):
    """Detect the track's tempo in BPM.

    Analyzes at most the first 60s at 22.05kHz mono (enough for tempo,
    much cheaper than a full decode). If *bpm_hint* is given, it is used
    to choose between the detected tempo and its half/double — octave
    errors are the most common beat-tracker failure mode.

    Returns the BPM rounded to 2 decimals, clamped into 70-200.
    """
    log("Detecting BPM...")
    y, sr = librosa.load(audio_path, sr=22050, mono=True, duration=60)
    tempo, _ = librosa.beat.beat_track(y=y, sr=sr)

    # librosa may return a scalar or a 1-element array depending on version
    if hasattr(tempo, '__len__'):
        bpm = float(tempo[0])
    else:
        bpm = float(tempo)

    # BUG FIX: beat_track can return 0 for silent/beatless input. Without
    # this guard, the octave-clamp loops below would never terminate.
    if not np.isfinite(bpm) or bpm <= 0:
        bpm = float(bpm_hint) if bpm_hint and bpm_hint > 0 else 120.0
        log(f"Tempo detection returned no usable value; falling back to {bpm}")

    # Resolve octave error using hint
    if bpm_hint and bpm_hint > 0:
        hint = float(bpm_hint)
        candidates = [bpm, bpm * 2, bpm / 2]
        bpm = min(candidates, key=lambda x: abs(x - hint))

    # Sanity: clamp to DJ-reasonable range
    while bpm < 70:
        bpm *= 2
    while bpm > 200:
        bpm /= 2

    bpm = round(bpm, 2)
    log(f"Detected BPM: {bpm}")
    return bpm


def get_beat_frames(audio_path, bpm):
    """Return (beat times in seconds, track duration in seconds) for *bpm*."""
    samples, rate = librosa.load(audio_path, sr=22050, mono=True)
    _, frames = librosa.beat.beat_track(y=samples, sr=rate, bpm=bpm)
    times = librosa.frames_to_time(frames, sr=rate)
    duration = len(samples) / rate
    return times, duration


def calculate_bar_duration(bpm):
    """Return the length in seconds of one bar (four beats) at *bpm*."""
    seconds_per_beat = 60.0 / bpm
    return 4 * seconds_per_beat


def load_audio_full(audio_path):
    """Decode *audio_path* to 44.1kHz 16-bit stereo WAV via ffmpeg and load it.

    Returns (data, sr) as produced by soundfile. Raises RuntimeError when
    ffmpeg fails or produces no output.
    """
    log("Loading full audio...")
    tmp_wav = audio_path + '.decode.wav'
    cmd = ['ffmpeg', '-y', '-i', audio_path, '-acodec', 'pcm_s16le', '-ar', '44100', '-ac', '2', tmp_wav]
    # BUG FIX: the return code was previously ignored, so a failed decode
    # surfaced only as a confusing missing-file error (or worse, a partial
    # file could slip through).
    result = subprocess.run(cmd, capture_output=True, text=True, timeout=120)
    if result.returncode != 0 or not os.path.exists(tmp_wav):
        stderr_tail = (result.stderr or '')[:300]
        raise RuntimeError(f"Failed to decode audio to WAV: {stderr_tail}")
    try:
        data, sr = sf.read(tmp_wav)
    finally:
        # BUG FIX: always remove the temp WAV, even when sf.read raises
        os.remove(tmp_wav)
    return data, sr


def crossfade_segments(seg_a, seg_b, fade_samples):
    """Join two segments with a linear crossfade at the seam.

    The overlap is clipped to the shorter of the two segments; with a
    non-positive overlap the segments are simply concatenated. Works for
    mono (1-D) and multi-channel (2-D, samples x channels) arrays.
    """
    overlap = min(fade_samples, len(seg_a), len(seg_b))
    if overlap <= 0:
        return np.concatenate([seg_a, seg_b])

    tail_a = seg_a[-overlap:]
    head_b = seg_b[:overlap]

    gain_down = np.linspace(1.0, 0.0, overlap)
    gain_up = np.linspace(0.0, 1.0, overlap)
    if tail_a.ndim == 2:
        # broadcast the fade curves across channels
        gain_down = gain_down[:, np.newaxis]
        gain_up = gain_up[:, np.newaxis]

    seam = tail_a * gain_down + head_b * gain_up
    return np.concatenate([seg_a[:-overlap], seam, seg_b[overlap:]])


def reduce_vocals(audio_data, sr):
    """Produce an instrumental version of *audio_data* by stripping vocals.

    Dispatches on channel count:
    - stereo (2-D): center-channel cancellation plus HPSS cleanup of the
      residual center content — much cleaner than HPSS alone;
    - mono (1-D): aggressive HPSS keeping percussion plus low harmonic.
    """
    is_stereo = audio_data.ndim == 2
    log(f"Removing vocals (stereo={is_stereo})...")
    if is_stereo:
        return remove_vocals_stereo(audio_data, sr)
    return remove_vocals_mono(audio_data, sr)


def remove_vocals_stereo(stereo, sr):
    """
    Remove vocals from stereo audio using center-channel cancellation.

    Vocals are almost always panned dead center in a stereo mix, so the
    side signal (L-R) is essentially vocal-free. From the discarded
    center (L+R) we salvage what instrumentals also live there: the
    percussive component (kick/snare/hats) and the sub-bass of the
    harmonic component. The pieces are then recombined into a stereo
    image.

    Returns a float64 array of shape (n_samples, 2).
    """
    left = stereo[:, 0]
    right = stereo[:, 1]

    # Center cancellation: (L - R) removes center-panned content (vocals)
    # This gives us the "sides" — instruments panned left/right
    sides = (left - right) / 2.0

    # The center channel (what we're removing)
    center = (left + right) / 2.0

    # From the center, extract only the percussive elements (drums/kick)
    # and LOW frequencies (bass), discard harmonic (vocals/melody)
    center_stft = librosa.stft(center)
    harmonic_stft, percussive_stft = librosa.decompose.hpss(center_stft, margin=3.0)

    # Keep center percussion fully (kick, snare, hats are center-panned)
    center_perc = librosa.istft(percussive_stft, length=len(center))

    # Keep only the sub-bass from center harmonic (< 200Hz)
    # Vocals are typically 200Hz-4kHz, so cutting above 200Hz removes them
    # NOTE: n_fft=2048 here must match the n_fft used by librosa.stft above
    # (its default), otherwise the mask length won't match the STFT bins
    freqs = librosa.fft_frequencies(sr=sr, n_fft=2048)
    bass_mask = np.zeros_like(freqs)
    bass_cutoff = 200  # Hz
    bass_mask[freqs <= bass_cutoff] = 1.0
    # Linear roll-off 1 → 0 over 200-300Hz to avoid a hard spectral edge
    transition = (freqs > bass_cutoff) & (freqs <= bass_cutoff * 1.5)
    bass_mask[transition] = 1.0 - (freqs[transition] - bass_cutoff) / (bass_cutoff * 0.5)

    center_bass_stft = harmonic_stft * bass_mask[:, np.newaxis]
    center_bass = librosa.istft(center_bass_stft, length=len(center))

    # Reconstruct: sides (L-R) + center percussion + center bass
    instrumental_mono = sides + center_perc + center_bass

    # Make stereo by creating a slight stereo image
    # (sides already have stereo info, center elements go to both channels)
    left_out = instrumental_mono + sides * 0.3
    right_out = instrumental_mono - sides * 0.3

    result = np.column_stack([left_out, right_out])
    return result.astype(np.float64)


def remove_vocals_mono(mono, sr):
    """
    Remove vocals from mono audio using aggressive HPSS.

    Without a stereo field there is no center to cancel, so this is less
    effective than the stereo method. Keeps the full percussive component
    plus only the sub-bass (< 200Hz) of the harmonic component, which
    discards the typical vocal range.

    Returns a float64 array the same length as *mono*.
    """
    stft = librosa.stft(mono)
    harmonic_stft, percussive_stft = librosa.decompose.hpss(stft, margin=3.0)

    percussive = librosa.istft(percussive_stft, length=len(mono))

    # From harmonic, keep only sub-bass (< 200Hz) — removes vocal frequencies
    # NOTE: n_fft=2048 must match librosa.stft's default used above
    freqs = librosa.fft_frequencies(sr=sr, n_fft=2048)
    bass_mask = np.zeros_like(freqs)
    bass_cutoff = 200
    bass_mask[freqs <= bass_cutoff] = 1.0
    # Linear roll-off 1 → 0 over 200-300Hz to avoid a hard spectral edge
    transition = (freqs > bass_cutoff) & (freqs <= bass_cutoff * 1.5)
    bass_mask[transition] = 1.0 - (freqs[transition] - bass_cutoff) / (bass_cutoff * 0.5)

    harmonic_bass_stft = harmonic_stft * bass_mask[:, np.newaxis]
    harmonic_bass = librosa.istft(harmonic_bass_stft, length=len(mono))

    result = percussive + harmonic_bass
    return result.astype(np.float64)


def find_best_instrumental_section(audio_data, sr, bpm, n_source_bars=4):
    """
    Find the most beat-heavy / instrumental section of the track.

    Separates the track into harmonic and percussive components (HPSS),
    then slides an *n_source_bars*-long window across the middle 80% of
    the track in 2-bar steps, scoring each window by percussive energy
    relative to harmonic energy, weighted by overall loudness. The
    highest-scoring window is the best candidate for a DJ intro loop
    (usually the chorus/reff instrumental part).

    Returns (section, start_sample): the slice taken from the original
    (possibly stereo) *audio_data*, and its start position in samples.
    Note: the window start is stepped in whole 2-bar increments from the
    10% mark, not snapped to the track's own downbeats.
    """
    log("Finding best instrumental section...")
    bar_dur = calculate_bar_duration(bpm)
    window_samples = int(n_source_bars * bar_dur * sr)

    # Analyze in mono; the final slice is taken from the full-channel audio
    if audio_data.ndim == 2:
        mono = np.mean(audio_data, axis=1)
    else:
        mono = audio_data

    # HPSS: separate into harmonic (vocals/melody) and percussive (drums/beats)
    stft = librosa.stft(mono)
    harmonic, percussive = librosa.decompose.hpss(stft)
    perc_audio = librosa.istft(percussive)
    harm_audio = librosa.istft(harmonic)

    # istft output can be slightly shorter than the input; align lengths
    min_len = min(len(perc_audio), len(harm_audio), len(mono))
    perc_audio = perc_audio[:min_len]
    harm_audio = harm_audio[:min_len]

    # Slide a window across the track in 2-bar steps, score each
    step_samples = int(2 * bar_dur * sr)
    best_score = -1
    best_pos = 0

    # Skip first 10% and last 10% of track (usually quiet intro/outro)
    start_search = int(min_len * 0.10)
    end_search = int(min_len * 0.90)

    pos = start_search
    while pos + window_samples <= end_search:
        perc_rms = np.sqrt(np.mean(perc_audio[pos:pos + window_samples] ** 2))
        harm_rms = np.sqrt(np.mean(harm_audio[pos:pos + window_samples] ** 2))
        total_rms = np.sqrt(np.mean(mono[pos:pos + window_samples] ** 2))

        # Score: high percussive energy + high total energy + lower harmonic ratio
        # We want loud, beat-heavy sections (the chorus instrumental)
        if harm_rms > 0 and total_rms > 0:
            beat_ratio = perc_rms / (harm_rms + 1e-10)
            score = beat_ratio * total_rms
        else:
            score = 0

        if score > best_score:
            best_score = score
            best_pos = pos

        pos += step_samples

    best_time = best_pos / sr
    log(f"Best section found at {best_time:.1f}s (score: {best_score:.4f})")

    # Extract section from the full (possibly stereo) audio
    end_pos = min(best_pos + window_samples, len(audio_data))
    section = audio_data[best_pos:end_pos].copy()

    return section, best_pos


def apply_filter_sweep(audio_data, sr, start_freq=200, end_freq=18000, direction='open'):
    """
    Apply a gradual low-pass filter sweep across the audio.

    'open': starts filtered (start_freq), ends open (end_freq) — for intros.
    'close': starts open, ends filtered — for outros.

    The sweep is approximated by filtering ~250ms chunks, each with a
    Butterworth low-pass whose cutoff moves linearly between the two
    frequencies. Works for mono (1-D) or multi-channel (2-D) input.

    BUG FIX: the trailing samples that don't fill a whole chunk used to be
    copied through UNFILTERED, causing an audible brightness jump at the
    very end of a 'close' sweep; they are now filtered with the final
    cutoff frequency.
    """
    n_samples = len(audio_data)
    chunk_size = sr // 4  # process in 250ms chunks for smooth sweep
    n_chunks = max(1, n_samples // chunk_size)

    if direction == 'open':
        freqs = np.linspace(start_freq, end_freq, n_chunks)
    else:
        freqs = np.linspace(end_freq, start_freq, n_chunks)

    result = np.zeros_like(audio_data)
    nyq = sr / 2.0

    def _lowpass(chunk, freq):
        # 3rd-order Butterworth; cutoff clamped to (0.01, 0.99) of Nyquist
        cutoff = min(freq / nyq, 0.99)
        cutoff = max(cutoff, 0.01)
        sos = butter(3, cutoff, btype='low', output='sos')
        if chunk.ndim == 2:
            out = np.empty_like(chunk)
            for ch in range(chunk.shape[1]):
                out[:, ch] = sosfilt(sos, chunk[:, ch])
            return out
        return sosfilt(sos, chunk)

    for i in range(n_chunks):
        start = i * chunk_size
        end = min(start + chunk_size, n_samples)
        result[start:end] = _lowpass(audio_data[start:end], freqs[i])

    # Filter (not copy) the remainder with the last cutoff of the sweep
    remaining = n_samples - n_chunks * chunk_size
    if remaining > 0:
        result[-remaining:] = _lowpass(audio_data[-remaining:], freqs[-1])

    return result


def generate_beat_intro(audio_data, sr, bpm, n_bars):
    """
    Generate a DJ-style beat intro of *n_bars* bars.

    Pipeline:
    1. Find the best instrumental/beat section (usually chorus) via HPSS analysis
    2. Reduce vocals from that section (keep drums, bass, synths)
    3. Loop 2-bar chunks of the instrumental to fill n_bars, with a low-pass
       filter sweep opening up across the intro
    4. Last 2 bars: crossfade back to the full mix for a smooth drop into
       the original track
    5. Normalize to ~85% of the source RMS so the real drop hits harder

    Returns the intro audio with the same channel layout as *audio_data*.
    """
    bar_dur = calculate_bar_duration(bpm)
    source_bars = 4  # take 4 bars for a musically interesting loop
    segment_samples = int(source_bars * bar_dur * sr)
    total_samples = int(n_bars * bar_dur * sr)

    # Find best beat-heavy section (usually chorus/reff)
    source, source_pos = find_best_instrumental_section(audio_data, sr, bpm, n_source_bars=source_bars)
    log(f"Using section at {source_pos/sr:.1f}s as intro source ({len(source)/sr:.1f}s)")

    # Pad (by repeated doubling) or trim the source to exactly source_bars bars
    if len(source) < segment_samples:
        while len(source) < segment_samples:
            source = np.concatenate([source, source])
    source = source[:segment_samples]

    # Remove vocals — keeps only drums + bass (pure instrumental)
    instrumental = reduce_vocals(source, sr)

    # Slice into 2-bar chunks for looping; a trailing chunk shorter than
    # 80% of 2 bars is discarded to keep the loop on-grid
    chunk_samples = int(2 * bar_dur * sr)
    chunks = []
    for i in range(0, len(instrumental), chunk_samples):
        c = instrumental[i:i + chunk_samples]
        if len(c) >= chunk_samples * 0.8:
            chunks.append(c[:chunk_samples] if len(c) >= chunk_samples else c)
    if not chunks:
        chunks = [instrumental[:chunk_samples]]

    # Build loop by cycling through instrumental chunks
    loop_parts = []
    total_built = 0
    chunk_idx = 0
    while total_built < total_samples:
        loop_parts.append(chunks[chunk_idx % len(chunks)])
        total_built += len(chunks[chunk_idx % len(chunks)])
        chunk_idx += 1
    loop = np.concatenate(loop_parts)[:total_samples]

    # Last 2 bars: crossfade from instrumental → full mix (creates the "build" before drop)
    blend_bars = min(2, n_bars - 1)  # for n_bars == 1 this is 0: no blend
    blend_samples = int(blend_bars * bar_dur * sr)
    blend_samples = min(blend_samples, total_samples)

    if blend_samples > 0 and total_samples > blend_samples:
        # Build a full-mix loop of matching chunks to crossfade against
        full_chunks = []
        for i in range(0, len(source), chunk_samples):
            c = source[i:i + chunk_samples]
            if len(c) >= chunk_samples * 0.8:
                full_chunks.append(c[:chunk_samples] if len(c) >= chunk_samples else c)
        if not full_chunks:
            full_chunks = [source[:chunk_samples]]

        full_parts = []
        total_full = 0
        cidx = 0
        while total_full < blend_samples:
            full_parts.append(full_chunks[cidx % len(full_chunks)])
            total_full += len(full_chunks[cidx % len(full_chunks)])
            cidx += 1
        full_blend = np.concatenate(full_parts)[:blend_samples]

        # Linear equal-gain crossfade over the blend region, written in place
        fade_in = np.linspace(0.0, 1.0, blend_samples)
        fade_out_curve = np.linspace(1.0, 0.0, blend_samples)
        if loop.ndim == 2:
            fade_in = fade_in[:, np.newaxis]
            fade_out_curve = fade_out_curve[:, np.newaxis]

        loop[-blend_samples:] = loop[-blend_samples:] * fade_out_curve + full_blend * fade_in

    # Apply filter sweep: starts muffled, opens up
    loop = apply_filter_sweep(loop, sr, start_freq=400, end_freq=18000, direction='open')

    # Normalize volume to match original (slightly quieter so the drop hits)
    orig_rms = np.sqrt(np.mean(source ** 2))
    loop_rms = np.sqrt(np.mean(loop ** 2))
    if loop_rms > 0:
        loop = loop * (orig_rms / loop_rms) * 0.85

    log(f"Beat intro generated: {n_bars} bars, {len(loop)/sr:.1f}s (source from {source_pos/sr:.1f}s)")
    return loop


def generate_beat_outro(audio_data, sr, bpm, n_bars, fade_out=True):
    """
    Generate a DJ-style beat outro of *n_bars* bars.

    Mirror image of generate_beat_intro:
    1. Find best beat section, reduce vocals
    2. First 2 bars: crossfade from full mix to instrumental
    3. Loop instrumental 2-bar chunks for the remaining bars
    4. Filter sweep: open → closed
    5. Optionally apply a linear fade-out (up to 4 bars, capped at 8s)
    6. Normalize to ~85% of the source RMS

    Returns the outro audio with the same channel layout as *audio_data*.
    """
    bar_dur = calculate_bar_duration(bpm)
    source_bars = 4
    segment_samples = int(source_bars * bar_dur * sr)
    total_samples = int(n_bars * bar_dur * sr)

    source, source_pos = find_best_instrumental_section(audio_data, sr, bpm, n_source_bars=source_bars)
    # Pad (by repeated doubling) or trim to exactly source_bars bars
    if len(source) < segment_samples:
        while len(source) < segment_samples:
            source = np.concatenate([source, source])
    source = source[:segment_samples]

    instrumental = reduce_vocals(source, sr)

    # Slice into 2-bar chunks; short trailing chunks (< 80%) are dropped
    chunk_samples = int(2 * bar_dur * sr)
    chunks = []
    for i in range(0, len(instrumental), chunk_samples):
        c = instrumental[i:i + chunk_samples]
        if len(c) >= chunk_samples * 0.8:
            chunks.append(c[:chunk_samples] if len(c) >= chunk_samples else c)
    if not chunks:
        chunks = [instrumental[:chunk_samples]]

    # Cycle chunks until the outro length is filled
    loop_parts = []
    total_built = 0
    chunk_idx = 0
    while total_built < total_samples:
        loop_parts.append(chunks[chunk_idx % len(chunks)])
        total_built += len(chunks[chunk_idx % len(chunks)])
        chunk_idx += 1
    loop = np.concatenate(loop_parts)[:total_samples]

    # First 2 bars: blend from full → instrumental
    blend_bars = min(2, n_bars - 1)  # for n_bars == 1 this is 0: no blend
    blend_samples = int(blend_bars * bar_dur * sr)
    blend_samples = min(blend_samples, total_samples)

    if blend_samples > 0:
        # Build a full-mix loop of matching chunks to crossfade against
        full_chunks = []
        for i in range(0, len(source), chunk_samples):
            c = source[i:i + chunk_samples]
            if len(c) >= chunk_samples * 0.8:
                full_chunks.append(c[:chunk_samples] if len(c) >= chunk_samples else c)
        if not full_chunks:
            full_chunks = [source[:chunk_samples]]

        full_parts = []
        total_full = 0
        cidx = 0
        while total_full < blend_samples:
            full_parts.append(full_chunks[cidx % len(full_chunks)])
            total_full += len(full_chunks[cidx % len(full_chunks)])
            cidx += 1
        full_blend = np.concatenate(full_parts)[:blend_samples]

        # Linear equal-gain crossfade over the opening region, in place
        fade_in = np.linspace(0.0, 1.0, blend_samples)
        fade_out_curve = np.linspace(1.0, 0.0, blend_samples)
        if loop.ndim == 2:
            fade_in = fade_in[:, np.newaxis]
            fade_out_curve = fade_out_curve[:, np.newaxis]

        loop[:blend_samples] = full_blend * fade_out_curve + loop[:blend_samples] * fade_in

    # Filter sweep: open → closed
    loop = apply_filter_sweep(loop, sr, start_freq=400, end_freq=18000, direction='close')

    # Optional linear fade to silence over the last 4 bars (max 8 seconds)
    if fade_out:
        fade_secs = min(bar_dur * 4, 8.0)
        fade_samples_count = int(fade_secs * sr)
        fade_samples_count = min(fade_samples_count, len(loop))
        fade_curve = np.linspace(1.0, 0.0, fade_samples_count)
        if loop.ndim == 2:
            fade_curve = fade_curve[:, np.newaxis]
        loop[-fade_samples_count:] *= fade_curve

    # Normalize volume to match original (slightly quieter, like the intro)
    orig_rms = np.sqrt(np.mean(source ** 2))
    loop_rms = np.sqrt(np.mean(loop ** 2))
    if loop_rms > 0:
        loop = loop * (orig_rms / loop_rms) * 0.85

    log(f"Beat outro generated: {n_bars} bars, {len(loop)/sr:.1f}s")
    return loop


def export_mp3(wav_path, mp3_path, bitrate='320k'):
    """Encode a WAV file to MP3 via ffmpeg/libmp3lame; raise on encoder failure."""
    log(f"Exporting MP3 at {bitrate}...")
    cmd = [
        'ffmpeg', '-y', '-i', wav_path,
        '-codec:a', 'libmp3lame', '-b:a', bitrate, '-q:a', '0',
        mp3_path,
    ]
    proc = subprocess.run(cmd, capture_output=True, text=True, timeout=120)
    if proc.returncode != 0:
        raise RuntimeError(f"MP3 export failed: {proc.stderr[:500]}")


def process_job(job):
    """Run the full pipeline for one DJ-edit job.

    Steps: download original from S3 → detect BPM → decode full audio →
    generate intro and/or outro → export MP3 → upload result → report
    completion via callback. Any failure is caught and reported back as
    a 'failed' callback instead of being raised; the temp directory is
    always removed.

    Expected *job* keys: job_id, s3_key, filename, edit_type (one of
    'intro', 'outro', 'intro_outro'); optional: intro_bars, outro_bars,
    fade_outro, bpm_hint, callback_url, api_key, user_id.
    """
    job_id = job['job_id']
    s3_key = job['s3_key']
    filename = job['filename']
    edit_type = job['edit_type']
    intro_bars = job.get('intro_bars')
    outro_bars = job.get('outro_bars')
    fade_outro = job.get('fade_outro', 1)
    bpm_hint = job.get('bpm_hint')
    callback_url = job.get('callback_url', '')
    api_key = job.get('api_key', '')
    user_id = job.get('user_id', 0)

    tmpdir = tempfile.mkdtemp(prefix=f'djedit_{job_id}_')
    log(f"Working in {tmpdir}")

    try:
        # 1. Download (keep the original extension so ffmpeg can sniff the format)
        send_progress(callback_url, api_key, job_id, 10, 'Downloading audio...')
        ext = os.path.splitext(s3_key)[1] or '.mp3'
        orig_path = os.path.join(tmpdir, f'original{ext}')
        download_from_s3(s3_key, orig_path)

        # 2. Detect BPM
        send_progress(callback_url, api_key, job_id, 25, 'Analyzing BPM...')
        bpm = detect_bpm(orig_path, bpm_hint)

        send_progress(callback_url, api_key, job_id, 40, f'BPM: {bpm} — Generating edit...')

        # 3. Load full audio (44.1kHz stereo via ffmpeg)
        audio, sr = load_audio_full(orig_path)
        log(f"Audio loaded: {len(audio)} samples, {sr}Hz, {audio.ndim}D, duration={len(audio)/sr:.1f}s")

        # 4. Generate edits (intro and/or outro, per edit_type)
        result = audio.copy()

        if edit_type in ('intro', 'intro_outro'):
            bars = int(intro_bars or 8)  # default: 8-bar intro
            send_progress(callback_url, api_key, job_id, 50, f'Building {bars}-bar beat intro...')
            log(f"Generating {bars}-bar beat intro...")
            intro = generate_beat_intro(audio, sr, bpm, bars)
            # Crossfade intro into original track (short seam)
            xfade = int(0.1 * sr)  # 100ms crossfade
            result = crossfade_segments(intro, result, xfade)
            log(f"Intro prepended. Total duration: {len(result)/sr:.1f}s")

        if edit_type in ('outro', 'intro_outro'):
            bars = int(outro_bars or 8)  # default: 8-bar outro
            send_progress(callback_url, api_key, job_id, 70, f'Building {bars}-bar beat outro...')
            log(f"Generating {bars}-bar beat outro...")
            outro = generate_beat_outro(audio, sr, bpm, bars, fade_out=bool(fade_outro))
            xfade = int(0.1 * sr)
            result = crossfade_segments(result, outro, xfade)
            log(f"Outro appended. Total duration: {len(result)/sr:.1f}s")

        # 5. Export (WAV intermediate, then ffmpeg → MP3 320k)
        send_progress(callback_url, api_key, job_id, 85, 'Exporting MP3...')
        wav_path = os.path.join(tmpdir, 'output.wav')
        mp3_path = os.path.join(tmpdir, 'output.mp3')
        sf.write(wav_path, result, sr)
        export_mp3(wav_path, mp3_path)

        if not os.path.exists(mp3_path) or os.path.getsize(mp3_path) < 1000:
            raise RuntimeError("Output MP3 is too small or missing")

        # 6. Upload to S3 under a descriptive per-user/per-job key
        send_progress(callback_url, api_key, job_id, 92, 'Uploading result...')
        base_name = os.path.splitext(filename)[0]
        edit_label = edit_type.replace('_', '-')
        bar_label = ''
        if intro_bars:
            bar_label += f'{intro_bars}bar'
        if outro_bars:
            bar_label += f'-{outro_bars}bar' if bar_label else f'{outro_bars}bar'
        output_filename = f'{base_name} ({edit_label} {bar_label}) - RPOOL DJ Edit.mp3'
        output_s3_key = f'dj-edits/output/{user_id}/{job_id}/{output_filename}'
        upload_to_s3(mp3_path, output_s3_key)

        # 7. Success callback
        log(f"Done! Output: {output_s3_key}")
        send_callback(callback_url, {
            'api_key': api_key,
            'job_id': job_id,
            'status': 'completed',
            'output_s3_key': output_s3_key,
            'output_filename': output_filename,
            'bpm': bpm,
            'progress': 100,
            'progress_message': 'Complete',
        })

    except Exception as e:
        # Report the failure to rpool.id (error message truncated to fit)
        log(f"ERROR: {e}")
        traceback.print_exc()
        send_callback(callback_url, {
            'api_key': api_key,
            'job_id': job_id,
            'status': 'failed',
            'error_message': str(e)[:450],
        })
    finally:
        shutil.rmtree(tmpdir, ignore_errors=True)
        log(f"Cleaned up {tmpdir}")


if __name__ == '__main__':
    # CLI entry point: the single argument is the job spec as
    # base64-encoded JSON (see module docstring for the expected keys).
    if len(sys.argv) < 2:
        print("Usage: python3 dj_edit_worker.py <job_json_base64>")
        sys.exit(1)

    job_b64 = sys.argv[1]
    job_data = json.loads(base64.b64decode(job_b64))
    log(f"Processing job #{job_data.get('job_id')} — {job_data.get('edit_type')} edit")
    process_job(job_data)
