"""360Shield Guard — Flask
Drop this file in your project root. Add one line to app.py. Done.

    from guard_flask import shield
    shield(app, key="YOUR_AGENT_KEY", domain="example.com")

The key MUST be an AGENT key (ada_agent_*), NOT a master key (ada_live_*).
Generate one per domain from your dashboard at https://360shield.net/dashboard
("Guard & API" tab) — the dashboard auto-creates a key for each verified domain.

Open source — https://github.com/asimetry/360shield-guard
Version: 1.0.3 | License: MIT | Zero dependencies (stdlib only)
"""

import os, sys, json, time, hashlib, hmac, threading, collections
from datetime import datetime, timezone
from urllib.request import Request, urlopen
from urllib.error import URLError

# Guard release string; sent in the heartbeat payload and in the
# X-Guard-Version / User-Agent headers of every API call.
__version__ = '1.0.3'
# Base URL of the agent API; endpoint paths (/register, /heartbeat) are appended.
_API = 'https://360shield.net/v1/agent'
_HB_INTERVAL = 30  # 30 seconds — frequent enough for near-realtime push commands
# Per-IP request count inside _FLOOD_WINDOW at which requests are answered
# with 429. Can be changed at runtime by the 'update_flood_limit' command.
_FLOOD_LIMIT = 50
# Length, in seconds, of the sliding window used for per-IP flood counting.
_FLOOD_WINDOW = 300
# Maximum number of request log entries kept in memory (deque maxlen).
_BUFFER_SIZE = 500


class _Shield:
    """In-process guard attached to a single Flask application.

    On construction it registers ``before_request`` / ``after_request`` hooks
    that reject blocked IPs, rate-limit floods, record request metadata in a
    bounded in-memory buffer and add security headers to every response. A
    daemon thread registers the guard with the API and then sends a heartbeat
    every ``_HB_INTERVAL`` seconds, executing the commands the server returns.
    """

    # Case-insensitive substrings that mark a User-Agent as automated.
    _BOT_PATTERNS = (
        'bot', 'crawler', 'spider', 'scraper', 'curl', 'wget',
        'python-requests', 'go-http', 'headless', 'phantom',
        'selenium', 'puppeteer', 'postman', 'httpie',
    )
    # Distributions whose installed versions are reported in the snapshot.
    _TRACKED_PACKAGES = ('flask', 'werkzeug', 'gunicorn', 'jinja2')
    # Minimum number of timestamps remembered per IP for flood tracking.
    _FLOOD_HISTORY_MIN = 200

    def __init__(self, app, key, domain):
        """Attach the guard to *app* and start the background heartbeat.

        Args:
            app: Flask application to protect.
            key: API key, sent as a Bearer token and used as the HMAC secret.
            domain: Domain name reported to the API.
        """
        self.app = app
        self.key = key
        self.domain = domain
        self.agent_id = None
        self._buf = collections.deque(maxlen=_BUFFER_SIZE)
        self._flood = {}        # ip -> ascending list of request timestamps
        self._blocked = set()   # IPs rejected with 403
        self._lock = threading.Lock()  # guards self._flood
        self._executed_cmd_ids = []  # Commands executed last cycle (to ack on next heartbeat)
        self._maintenance = False

        # Register hooks
        app.before_request(self._before)
        app.after_request(self._after)

        # Register + start heartbeat in background
        t = threading.Thread(target=self._boot, daemon=True)
        t.start()

    # ── Request Hooks ──

    def _before(self):
        """Block, rate-limit and log the incoming request.

        Returns:
            ``None`` to let Flask continue, or a ``(response, status)`` tuple
            (403 for blocked IPs, 429 for floods) to short-circuit the request.
        """
        from flask import request, jsonify, g
        g._shield_t0 = time.time()
        # Log entry owned by *this* request; stays None when nothing is logged
        # (static assets, blocked or rate-limited requests) so _after never
        # touches an entry that belongs to a different request.
        g._shield_entry = None
        path = request.path

        if path.startswith('/static/') or path == '/favicon.ico':
            return None

        ip = self._ip(request)

        # Blocked?
        if ip in self._blocked:
            return jsonify({'error': 'blocked'}), 403

        # Flood check
        now = int(time.time())
        # Remember at least as many timestamps as the limit, otherwise a limit
        # above the history cap could never be reached.
        keep = max(self._FLOOD_HISTORY_MIN, _FLOOD_LIMIT)
        with self._lock:
            recent = [t for t in self._flood.get(ip, ()) if now - t < _FLOOD_WINDOW]
            recent.append(now)
            self._flood[ip] = recent[-keep:]
            flooded = len(recent) >= _FLOOD_LIMIT
        if flooded:
            return jsonify({'error': 'rate_limited', 'retry_after': 60}), 429

        # Log request
        ua = request.headers.get('User-Agent', '')
        entry = {
            'time': now, 'ip': ip, 'method': request.method,
            'path': path[:80], 'ua': ua[:120],
            'is_bot': self._is_bot(ua), 'status': None, 'ms': None,
        }
        self._buf.append(entry)
        g._shield_entry = entry
        return None

    def _after(self, response):
        """Complete this request's log entry and add security headers.

        Args:
            response: The Flask response about to be sent.

        Returns:
            The same response object, with headers adjusted.
        """
        from flask import g
        try:
            entry = getattr(g, '_shield_entry', None)
            if entry is not None:
                entry['status'] = response.status_code
                t0 = getattr(g, '_shield_t0', None)
                if t0 is not None:
                    entry['ms'] = int((time.time() - t0) * 1000)
        except RuntimeError:
            # Best effort: never let logging break the response.
            pass

        # Security headers
        response.headers['X-Content-Type-Options'] = 'nosniff'
        response.headers['X-Frame-Options'] = 'DENY'
        response.headers['X-XSS-Protection'] = '1; mode=block'
        response.headers['Referrer-Policy'] = 'strict-origin-when-cross-origin'
        response.headers.pop('Server', None)
        response.headers.pop('X-Powered-By', None)
        return response

    # ── Heartbeat ──

    def _boot(self):
        """Background thread body: register once, then heartbeat forever."""
        time.sleep(2)  # let app start
        self._register()
        while True:
            try:
                self._heartbeat()
            except Exception as e:
                print(f'[360Shield] heartbeat error: {e}', file=sys.stderr)
            time.sleep(_HB_INTERVAL)

    def _register(self):
        """Register the domain with the API and store the returned agent id."""
        try:
            body = json.dumps({'domain': self.domain}).encode()
            resp = self._api_call('/register', body)
            if resp and resp.get('agent_id'):
                self.agent_id = resp['agent_id']
                print(f'[360Shield] Guard registered: {self.agent_id}')
        except Exception as e:
            print(f'[360Shield] register failed: {e}', file=sys.stderr)

    def _heartbeat(self):
        """Send one heartbeat and execute the commands returned by the server.

        Command ids executed during the previous cycle are sent as acks. They
        are put back for the next cycle whenever the call fails, including
        when it raises, so an ack is never silently lost.
        """
        self._prune_flood()
        # Build payload (includes acked_command_ids from last cycle)
        payload = self._build_payload()
        # Snapshot acked ids that we're sending; clear local list so we don't resend
        sent_acks = list(self._executed_cmd_ids)
        self._executed_cmd_ids = []
        try:
            body = json.dumps(payload, default=str).encode()
            resp = self._api_call('/heartbeat', body, sign=True)
        except Exception:
            self._executed_cmd_ids = sent_acks + self._executed_cmd_ids
            raise
        if resp and isinstance(resp, dict):
            self._process_commands(resp.get('commands') or [])
        else:
            # Heartbeat failed — restore acks so we retry next cycle
            self._executed_cmd_ids = sent_acks + self._executed_cmd_ids

    def _prune_flood(self):
        """Forget IPs with no request inside the flood window.

        Without this the tracking dict keeps one entry for every IP ever
        seen and grows for the lifetime of the process.
        """
        now = int(time.time())
        with self._lock:
            # Timestamps are appended in order, so the last one is the newest.
            stale = [ip for ip, stamps in self._flood.items()
                     if not stamps or now - stamps[-1] >= _FLOOD_WINDOW]
            for ip in stale:
                del self._flood[ip]

    def _build_payload(self):
        """Assemble the heartbeat payload from the last 300 seconds of traffic.

        Returns:
            A JSON-serialisable dict with traffic aggregates, the recent
            request sample, the security snapshot, guard state and the ids of
            commands to acknowledge.
        """
        reqs = list(self._buf)
        now = int(time.time())
        window = [r for r in reqs if now - r['time'] < 300]

        total = len(window)
        bots = sum(1 for r in window if r.get('is_bot'))
        ips = {r['ip'] for r in window}
        status_counts = {}
        for r in window:
            s = r.get('status')
            if s:
                status_counts[str(s)] = status_counts.get(str(s), 0) + 1

        # ── Sample recent_requests for server-side AoM ──
        # Cap at 500 to keep payload size sane (5 min of moderate traffic).
        # Timestamps are sent as float epoch seconds.
        sampled = window[-500:]
        recent_requests = [{
            'ts': float(r.get('time', 0)),
            'ip': r.get('ip', ''),
            'method': r.get('method', 'GET'),
            'path': r.get('path', '/'),
            'status': r.get('status') or 200,
            'is_bot': bool(r.get('is_bot')),
            'ua': r.get('ua', ''),
        } for r in sampled]

        # Security config snapshot
        security = self._security_snapshot()

        return {
            'agent_id': self.agent_id or 'pending',
            'domain': self.domain,
            'guard_version': __version__,
            'timestamp': datetime.now(timezone.utc).isoformat(),
            'traffic': {
                'total_requests': total,
                'bot_requests': bots,
                'unique_ips': len(ips),
                'status_codes': status_counts,
                'rpm': round(total / 5, 1) if total else 0,
            },
            'recent_requests': recent_requests,
            'security': security,
            'guard_state': {
                'blocked_count': len(self._blocked),
                'buffer_size': len(self._buf),
                'flood_tracked_ips': len(self._flood),
                'maintenance_mode': getattr(self, '_maintenance', False),
                # Report the real number of patterns used by _is_bot.
                'bot_patterns': len(self._BOT_PATTERNS),
            },
            'acked_command_ids': self._executed_cmd_ids[:],  # commands executed since last heartbeat
        }

    def _security_snapshot(self):
        """Read Flask app security config — no user data, only config.

        Returns:
            A dict with selected config flags, the secret key *length* (never
            the key itself) and the versions of the tracked packages.
        """
        snap = {'platform': 'flask'}
        try:
            c = self.app.config
            snap['debug_mode'] = c.get('DEBUG', False)
            # An unset SECRET_KEY is None; measure real str/bytes keys with
            # len() instead of the length of their repr.
            secret = c.get('SECRET_KEY')
            if not secret:
                snap['secret_key_length'] = 0
            elif isinstance(secret, (str, bytes, bytearray)):
                snap['secret_key_length'] = len(secret)
            else:
                snap['secret_key_length'] = len(str(secret))
            snap['session_cookie_httponly'] = c.get('SESSION_COOKIE_HTTPONLY', False)
            snap['session_cookie_secure'] = c.get('SESSION_COOKIE_SECURE', False)
            snap['session_cookie_samesite'] = c.get('SESSION_COOKIE_SAMESITE', '')
            snap['max_content_length'] = c.get('MAX_CONTENT_LENGTH')
            # Extensions register themselves on app.extensions, not in config;
            # the config lookup is kept for backward compatibility.
            extensions = getattr(self.app, 'extensions', None) or {}
            snap['csrf_protection'] = (
                any('csrf' in str(name).lower() for name in extensions)
                or 'csrf' in str(c.get('extensions', '')).lower()
                or hasattr(self.app, 'csrf')
            )
        except Exception as e:
            print(f'[360Shield] security snapshot error: {e}', file=sys.stderr)

        # Package versions
        packages = {}
        try:
            import importlib.metadata as meta
            for name in self._TRACKED_PACKAGES:
                try:
                    # version() normalises the name, so 'flask' also matches
                    # a distribution published as 'Flask'.
                    packages[name] = meta.version(name)
                except meta.PackageNotFoundError:
                    continue
        except Exception:
            packages = {}
        snap['packages'] = packages

        return snap

    # ── Commands ──

    def _process_commands(self, commands):
        """Execute commands from server. Track IDs of executed commands so we
        can ack them on the next heartbeat (so server can mark them done in queue).

        A command that raises (malformed entry, invalid flood limit) is
        logged and not acknowledged; the remaining commands still run.

        Args:
            commands: List of command dicts with ``action`` and optional ``id``.
        """
        global _FLOOD_LIMIT
        if not isinstance(commands, (list, tuple)):
            return
        executed_ids = []
        for cmd in commands:
            action = ''
            try:
                action = cmd.get('action', '')
                cmd_id = cmd.get('id')
                if action == 'block_ip':
                    ip = cmd.get('ip', '')
                    if ip:
                        self._blocked.add(ip)
                        print(f'[360Shield] blocked IP: {ip}')
                elif action == 'unblock_ip':
                    ip = cmd.get('ip', '')
                    self._blocked.discard(ip)
                    if ip:
                        print(f'[360Shield] unblocked IP: {ip}')
                elif action == 'reset_blocked':
                    n = len(self._blocked)
                    self._blocked.clear()
                    print(f'[360Shield] reset all blocked IPs ({n})')
                elif action == 'update_flood_limit':
                    # Validate before storing: a non-numeric value would make
                    # every request raise, and a value below 1 would
                    # rate-limit all traffic.
                    raw = cmd.get('limit', 50)
                    if isinstance(raw, bool):
                        raise ValueError(f'invalid flood limit: {raw!r}')
                    limit = int(raw)
                    if limit < 1:
                        raise ValueError(f'invalid flood limit: {raw!r}')
                    _FLOOD_LIMIT = limit
                    print(f'[360Shield] flood limit: {_FLOOD_LIMIT}')
                elif action == 'force_scan':
                    # Trigger immediate AoM/SC re-evaluation on next heartbeat
                    # (no-op here — heartbeat itself sends fresh data)
                    print('[360Shield] force_scan acknowledged')
                elif action == 'set_maintenance':
                    # NOTE(review): the flag is recorded and reported in
                    # guard_state, but no request hook in this class reads it.
                    self._maintenance = bool(cmd.get('enabled', False))
                    print(f'[360Shield] maintenance mode: {self._maintenance}')

                if cmd_id is not None:
                    executed_ids.append(cmd_id)
            except Exception as e:
                print(f'[360Shield] command error ({action}): {e}')

        if executed_ids:
            self._executed_cmd_ids.extend(executed_ids)

    # ── Helpers ──

    def _ip(self, request):
        """Return the client IP, preferring proxy headers over the socket peer.

        NOTE(review): these headers are client-controlled unless a trusted
        proxy overwrites them, so a direct client can spoof its IP to evade
        blocks and flood limits — confirm the deployment sits behind a proxy.
        """
        for h in ('CF-Connecting-IP', 'X-Real-IP', 'X-Forwarded-For'):
            v = request.headers.get(h)
            if v:
                candidate = v.split(',')[0].strip()
                # Skip headers whose first element is blank.
                if candidate:
                    return candidate
        return request.remote_addr or ''

    def _is_bot(self, ua):
        """Return True when *ua* is empty or contains a known bot pattern."""
        ua = (ua or '').lower()
        return not ua or any(b in ua for b in self._BOT_PATTERNS)

    def _api_call(self, endpoint, body, sign=False):
        """POST *body* to the API and return the decoded JSON response.

        Args:
            endpoint: Path appended to the API base URL, e.g. ``/heartbeat``.
            body: Request body as bytes.
            sign: When true, add an HMAC-SHA256 ``X-Guard-Signature`` header.

        Returns:
            The parsed JSON value, or ``None`` on a network error, timeout or
            undecodable response.
        """
        headers = {
            'Content-Type': 'application/json',
            'Authorization': f'Bearer {self.key}',
            'X-Guard-Version': __version__,
            'User-Agent': f'360Shield-Guard/{__version__}',
        }
        if sign:
            headers['X-Guard-Signature'] = self._sign(body)

        req = Request(_API + endpoint, data=body, headers=headers, method='POST')
        try:
            with urlopen(req, timeout=10) as r:
                return json.loads(r.read())
        except (OSError, ValueError) as e:
            # OSError covers URLError/HTTPError and socket timeouts raised
            # while reading; ValueError covers invalid JSON.
            print(f'[360Shield] API error: {e}', file=sys.stderr)
            return None

    def _sign(self, body):
        """Return the hex HMAC-SHA256 of *body* keyed with the API key."""
        return hmac.new(self.key.encode(), body, hashlib.sha256).hexdigest()


def shield(app, key=None, domain=None):
    """One-line Guard activation.

        from guard_flask import shield
        shield(app, key="ada_agent_xxx", domain="example.com")

    Or use environment variables:
        SHIELD_AGENT_KEY=ada_agent_xxx SHIELD_DOMAIN=example.com

    Args:
        app: Flask application to protect.
        key: Agent key (``ada_agent_*``); falls back to ``SHIELD_AGENT_KEY``.
        domain: Protected domain; falls back to ``SHIELD_DOMAIN``.

    Returns:
        The active guard instance, or ``None`` when key or domain is missing
        (a warning is printed to stderr and the app is left untouched).
    """
    key = key or os.environ.get('SHIELD_AGENT_KEY', '')
    domain = domain or os.environ.get('SHIELD_DOMAIN', '')
    # Values pasted from dashboards or .env files often carry stray
    # whitespace or a trailing newline, which would corrupt the auth header.
    if isinstance(key, str):
        key = key.strip()
    if isinstance(domain, str):
        domain = domain.strip()
    if not key or not domain:
        print('[360Shield] WARNING: SHIELD_AGENT_KEY and SHIELD_DOMAIN required. Guard NOT active.', file=sys.stderr)
        return None
    if isinstance(key, str) and key.startswith('ada_live_'):
        # Master keys must not be embedded in a deployed guard; warn but keep
        # going so existing setups behave as before.
        print('[360Shield] WARNING: key looks like a master key (ada_live_*); use an agent key (ada_agent_*).', file=sys.stderr)
    return _Shield(app, key, domain)
