import re, os, sys

# Read the target file once. Every section below patches this in-memory copy;
# the file on disk is rewritten only at the very end. Read as UTF-8 explicitly
# because the anchors contain non-ASCII text (e.g. the '—' in the
# load_all_agent_state anchor), which would not match if the platform's
# default encoding were used on a non-UTF-8 locale.
with open('StockTrading.py', encoding='utf-8') as _fh:
    src = _fh.read()
changes = 0  # running count of patch steps, reported at the end

# ── REVERT 1: cooldown back to 5 seconds ──────────────────────────────────────
# Match on the *full* line being replaced. The previous guard checked only the
# 'BUY_COOLDOWN_SECS      = 120' prefix, so it fired (and reported a change)
# even when the trailing comment differed and replace() was a no-op.
OLD_COOLDOWN_LINE = 'BUY_COOLDOWN_SECS      = 120   # 2-min base; per-agent override in __init__'
NEW_COOLDOWN_LINE = 'BUY_COOLDOWN_SECS      = 5'
if OLD_COOLDOWN_LINE in src:
    src = src.replace(OLD_COOLDOWN_LINE, NEW_COOLDOWN_LINE)
    changes += 1
    print("REVERT 1: BUY_COOLDOWN_SECS back to 5")
elif 'BUY_COOLDOWN_SECS      = 120' in src:
    # The constant is present but its comment differs from the expected text.
    print("REVERT 1: WARNING — BUY_COOLDOWN_SECS = 120 found with an "
          "unexpected comment; not changed")

# ── REVERT 2: strip the per-agent cooldown_secs block from TradingAgent.__init__ ─
# Each entry is one source line of the block to delete. Every line, including
# the last, ends in '\n' once joined.
_cooldown_lines = [
    '        # Per-agent cooldown: HighFreq=60s, LowFreq=300s, default=120s',
    '        _nm = self.name.lower()',
    '        if any(x in _nm for x in ("highfreq","ultraaggress","momentum")):',
    '            self.cooldown_secs = 60',
    '        elif any(x in _nm for x in ("lowfreq","conservative","ultraconserv","meanreversion")):',
    '            self.cooldown_secs = 300',
    '        else:',
    '            self.cooldown_secs = 120',
]
OLD_COOLDOWN_BLOCK = ''.join(f'{ln}\n' for ln in _cooldown_lines)

# Deleting a non-empty substring changes the text iff the substring was present.
_patched = src.replace(OLD_COOLDOWN_BLOCK, '')
if _patched != src:
    src = _patched
    changes += 1
    print("REVERT 2: per-agent cooldown_secs removed")

# ── REVERT 3: buy/short back to BUY_COOLDOWN_SECS directly ───────────────────
# Undo the getattr(self, "cooldown_secs", ...) lookups in the buy and short
# cooldown checks. Count and report the step only when at least one line
# matched. Previously it was counted unconditionally, which inflated the total.
_REVERT3_PAIRS = (
    ('if time.time() - self._last_sold.get(sym, 0) < getattr(self, "cooldown_secs", BUY_COOLDOWN_SECS):',
     'if time.time() - self._last_sold.get(sym, 0) < BUY_COOLDOWN_SECS:'),
    ('if time.time() - self._last_sold.get(sym + "_short", 0) < getattr(self, "cooldown_secs", BUY_COOLDOWN_SECS):',
     'if time.time() - self._last_sold.get(sym + "_short", 0) < BUY_COOLDOWN_SECS:'),
)
_revert3_hits = 0
for _old_line, _new_line in _REVERT3_PAIRS:
    _revert3_hits += src.count(_old_line)
    src = src.replace(_old_line, _new_line)
if _revert3_hits:
    changes += 1
    print("REVERT 3: buy/short use BUY_COOLDOWN_SECS directly")

# ── REVERT 4: drop the market-impact settings from TradingSystem.__init__ ─────
# The three lines below are removed verbatim (each terminated by a newline).
_impact_init_lines = (
    '        # Market impact: max agents that may hold same stock simultaneously',
    '        self._IMPACT_MAX_AGENTS   = 3',
    '        self._IMPACT_CONF_PENALTY = 0.05',
)
OLD_IMPACT_INIT = '\n'.join(_impact_init_lines) + '\n'

_impact_hits = src.count(OLD_IMPACT_INIT)
if _impact_hits > 0:
    src = src.replace(OLD_IMPACT_INIT, '')
    changes += 1
    print("REVERT 4: market impact constants removed from __init__")

# ── REVERT 5: remove market impact check from _trading_pass ───────────────────
# Replace the impact-gated buy block with simple direct buy
# The holder-count / confidence-penalty gating around ag.buy() and ag.short()
# is replaced by the ungated calls. This is an exact-text match: if the anchor
# differs at all (indentation, spacing), nothing is replaced and nothing is
# printed.
# NOTE(review): REVERT 4 deletes the self._IMPACT_* initialisers that this gated
# code reads. If REVERT 4 matched but this anchor does not, the patched file
# will still reference those attributes at runtime, so check the printed output.
OLD_IMPACT_BUY = (
    '                if d == 1:   # bullish\n'
    '                    if sym in ag.short_positions:\n'
    '                        ag.cover(sym, price, reason="signal")\n'
    '                    elif sym not in ag.positions and self.allow_long:\n'
    '                        holders = sum(1 for a in self.agents if sym in a.positions)\n'
    '                        if holders < self._IMPACT_MAX_AGENTS:\n'
    '                            adj = ag.min_conf + holders * self._IMPACT_CONF_PENALTY\n'
    '                            if eff_cf >= adj:\n'
    '                                _, pred_pct, _ = ag.model.predict(vecs[i])\n'
    '                                ag.buy(sym, price, pred_pct or 0.0, cf, meta_cf, eff_cf, atr_pct)\n'
    '\n'
    '                elif d == 0:  # bearish\n'
    '                    if sym in ag.positions:\n'
    '                        ag.sell(sym, price, reason="signal")\n'
    '                    elif sym not in ag.short_positions and self.allow_short:\n'
    '                        s_holders = sum(1 for a in self.agents if sym in a.short_positions)\n'
    '                        if s_holders < self._IMPACT_MAX_AGENTS:\n'
    '                            adj = ag.min_conf + s_holders * self._IMPACT_CONF_PENALTY\n'
    '                            if eff_cf >= adj:\n'
    '                                _, pred_pct, _ = ag.model.predict(vecs[i])\n'
    '                                ag.short(sym, price, pred_pct or 0.0, cf, meta_cf, eff_cf)'
)
# Same branch structure with the impact gate removed: predict, then buy/short.
NEW_CLEAN_BUY = (
    '                if d == 1:   # bullish\n'
    '                    if sym in ag.short_positions:\n'
    '                        ag.cover(sym, price, reason="signal")\n'
    '                    elif sym not in ag.positions and self.allow_long:\n'
    '                        _, pred_pct, _ = ag.model.predict(vecs[i])\n'
    '                        ag.buy(sym, price, pred_pct or 0.0, cf, meta_cf, eff_cf, atr_pct)\n'
    '\n'
    '                elif d == 0:  # bearish\n'
    '                    if sym in ag.positions:\n'
    '                        ag.sell(sym, price, reason="signal")\n'
    '                    elif sym not in ag.short_positions and self.allow_short:\n'
    '                        _, pred_pct, _ = ag.model.predict(vecs[i])\n'
    '                        ag.short(sym, price, pred_pct or 0.0, cf, meta_cf, eff_cf)'
)
if OLD_IMPACT_BUY in src:
    src = src.replace(OLD_IMPACT_BUY, NEW_CLEAN_BUY)
    changes += 1
    print("REVERT 5: market impact check removed from _trading_pass — agents compete freely")

# ══════════════════════════════════════════════════════════════════════════════
# ADD: MetaLearner as 33rd trading agent
# ══════════════════════════════════════════════════════════════════════════════

# ── ADD 1: _NeutralMeta dummy class — placed just before class TradingAgent ───
# _NeutralMeta is a stand-in for TradingAgent's `meta` argument.
# meta_confidence() and predict_batch() always return 0.5.
# Guards raise SystemExit rather than using `assert`, because asserts are
# stripped under `python -O` and the patch would then continue against an
# unexpected file. The step is skipped if the class is already present, so a
# rerun does not insert a second copy.
TRADING_AGENT_MARKER = 'class TradingAgent:\n    def __init__(self, cfg: dict, model: StockModel, meta: MetaLearner):'
NEUTRAL_META = (
    'class _NeutralMeta:\n'
    '    """\n'
    '    Dummy meta for the MetaLearner\'s own TradingAgent.\n'
    '    The MetaLearner is its own model, so it doesn\'t need an external meta.\n'
    '    Returns 0.5 (neutral) so the agent trades purely on its own signal.\n'
    '    Effect: eff_cf = cf * 0.7 + 0.5 * 0.3 = cf * 0.7 + 0.15\n'
    '    """\n'
    '    trained = True\n'
    '    preds   = []\n'
    '    profitable_symbols = {}\n'
    '\n'
    '    def meta_confidence(self, x) -> float:\n'
    '        return 0.5\n'
    '\n'
    '    def predict_batch(self, feat_matrix):\n'
    '        import numpy as _np\n'
    '        return _np.full(len(feat_matrix), 0.5, dtype=_np.float32)\n'
    '\n'
    '\n'
)
if 'class _NeutralMeta:' in src:
    print("ADD 1: skipped — _NeutralMeta already present")
elif TRADING_AGENT_MARKER not in src:
    raise SystemExit("ADD1: TradingAgent class not found")
else:
    src = src.replace(TRADING_AGENT_MARKER, NEUTRAL_META + TRADING_AGENT_MARKER)
    changes += 1
    print("ADD 1: _NeutralMeta dummy class added")

# ── ADD 2: Initialise meta_agent in TradingSystem.__init__ ────────────────────
# Append construction of self.meta_agent right after the regular agents. The
# MetaLearner instance (self.meta) is passed as the agent's model, and a
# _NeutralMeta() as its meta.
# OLD_AGENTS_INIT is a prefix of NEW_AGENTS_INIT, so without the "already
# applied" check a rerun would insert a second meta_agent initialisation.
# NOTE(review): _meta_cfg sets no capital; the "$1,000" in the injected comment
# presumably comes from TradingAgent's default. Confirm.
OLD_AGENTS_INIT = (
    '        self.models = {c["name"]: StockModel(c["name"], c) for c in AGENT_CONFIGS}\n'
    '        self.agents = [TradingAgent(c, self.models[c["name"]], self.meta)\n'
    '                       for c in AGENT_CONFIGS]'
)
NEW_AGENTS_INIT = (
    '        self.models = {c["name"]: StockModel(c["name"], c) for c in AGENT_CONFIGS}\n'
    '        self.agents = [TradingAgent(c, self.models[c["name"]], self.meta)\n'
    '                       for c in AGENT_CONFIGS]\n'
    '        # 33rd agent: the MetaLearner itself trades with its own $1,000\n'
    '        # It uses _NeutralMeta so it relies purely on its own predictions.\n'
    '        # Competing alongside the 32 agents shows how well meta-learning works.\n'
    '        _meta_cfg = {\n'
    '            "name": "MetaLearner", "risk_tolerance": 0.5,\n'
    '            "max_position_pct": 0.2, "min_confidence": 0.55,\n'
    '            "is_clone": False, "take_profit_mult": 1.3,\n'
    '        }\n'
    '        self.meta_agent = TradingAgent(_meta_cfg, self.meta, _NeutralMeta())'
)
if 'self.meta_agent = TradingAgent(' in src:
    print("ADD 2: skipped — meta_agent already initialized")
elif OLD_AGENTS_INIT not in src:
    # Explicit check instead of `assert`, which `python -O` would strip.
    raise SystemExit("ADD2: agents init not found")
else:
    src = src.replace(OLD_AGENTS_INIT, NEW_AGENTS_INIT)
    changes += 1
    print("ADD 2: meta_agent (33rd trader) initialized in TradingSystem.__init__")

# ── ADD 3: Load meta_agent model in initialize() ─────────────────────────────
# No load call is added. meta_agent's model is the same self.meta object that
# is already loaded, so only an explanatory comment is inserted.
# OLD_LOAD_MODELS is a prefix of NEW_LOAD_MODELS, so the "already applied"
# check must come first; otherwise a rerun would duplicate the comment.
OLD_LOAD_MODELS = (
    '        if not self.collect_only:\n'
    '            _log_ram("before model load")\n'
    '            for a in self.agents:\n'
    '                a.model.load()\n'
    '            self.meta.load()\n'
)
NEW_LOAD_MODELS = (
    '        if not self.collect_only:\n'
    '            _log_ram("before model load")\n'
    '            for a in self.agents:\n'
    '                a.model.load()\n'
    '            self.meta.load()\n'
    '            # meta_agent uses same model object as self.meta (already loaded above)\n'
)
if NEW_LOAD_MODELS in src:
    print("ADD 3: skipped — meta_agent load note already present")
elif OLD_LOAD_MODELS not in src:
    # Explicit check instead of `assert`, which `python -O` would strip.
    raise SystemExit("ADD3: model load section not found")
else:
    src = src.replace(OLD_LOAD_MODELS, NEW_LOAD_MODELS)
    changes += 1
    print("ADD 3: meta_agent model load noted (shares self.meta)")

# ── ADD 4: Include meta_agent in load_all_agent_state ────────────────────────
# Give load_all_agent_state an optional `meta_agent=None` parameter, so existing
# callers keep working, and restore its state from the "MetaLearner" key of the
# state file when both the argument and the key are present.
# Skipped when already applied. A missing anchor aborts with SystemExit instead
# of an `assert`, which `python -O` would strip.
OLD_LOAD_STATE = (
    'def load_all_agent_state(agents: list, meta: MetaLearner):\n'
    '    if not os.path.exists(AGENT_STATE_FILE):\n'
    '        logger.info("No saved agent state — starting fresh.")\n'
    '        return\n'
    '    try:\n'
    '        with open(AGENT_STATE_FILE) as f:\n'
    '            all_state = json.load(f)\n'
    '        for a in agents:\n'
    '            if a.name in all_state:\n'
    '                a.load_state(all_state[a.name])\n'
    '                logger.info(f"  [{a.name}] Restored: capital=${a.capital:.2f} pnl=${a.total_pnl:+.2f}")\n'
    '        meta_s = all_state.get("__meta__", {})\n'
    '        meta.preds              = meta_s.get("preds", [])\n'
    '        meta.profitable_symbols = meta_s.get("profitable_symbols", {})\n'
    '        logger.info("Agent state loaded.")\n'
    '    except Exception as e:\n'
    '        logger.warning(f"Could not load agent state: {e}")'
)
NEW_LOAD_STATE = (
    'def load_all_agent_state(agents: list, meta: MetaLearner, meta_agent=None):\n'
    '    if not os.path.exists(AGENT_STATE_FILE):\n'
    '        logger.info("No saved agent state — starting fresh.")\n'
    '        return\n'
    '    try:\n'
    '        with open(AGENT_STATE_FILE) as f:\n'
    '            all_state = json.load(f)\n'
    '        for a in agents:\n'
    '            if a.name in all_state:\n'
    '                a.load_state(all_state[a.name])\n'
    '                logger.info(f"  [{a.name}] Restored: capital=${a.capital:.2f} pnl=${a.total_pnl:+.2f}")\n'
    '        # Restore MetaLearner trading agent state\n'
    '        if meta_agent and "MetaLearner" in all_state:\n'
    '            meta_agent.load_state(all_state["MetaLearner"])\n'
    '            logger.info(f"  [MetaLearner] Restored: capital=${meta_agent.capital:.2f} pnl=${meta_agent.total_pnl:+.2f}")\n'
    '        meta_s = all_state.get("__meta__", {})\n'
    '        meta.preds              = meta_s.get("preds", [])\n'
    '        meta.profitable_symbols = meta_s.get("profitable_symbols", {})\n'
    '        logger.info("Agent state loaded.")\n'
    '    except Exception as e:\n'
    '        logger.warning(f"Could not load agent state: {e}")'
)
if NEW_LOAD_STATE in src:
    print("ADD 4: skipped — load_all_agent_state already handles meta_agent")
elif OLD_LOAD_STATE not in src:
    raise SystemExit("ADD4: load_all_agent_state not found")
else:
    src = src.replace(OLD_LOAD_STATE, NEW_LOAD_STATE)
    changes += 1
    print("ADD 4: meta_agent state loaded in load_all_agent_state")

# ── ADD 5: Include meta_agent in save_all_agent_state ────────────────────────
# Give save_all_agent_state an optional trailing `meta_agent=None` parameter,
# so existing callers keep working, and store its state under the
# "MetaLearner" key when one is passed.
# Skipped when already applied. A missing anchor aborts with SystemExit instead
# of an `assert`, which `python -O` would strip.
OLD_SAVE_ALL = (
    'def save_all_agent_state(agents: list, meta: MetaLearner, prices: dict = None):\n'
    '    """Save agent state. Pass prices so open positions are marked to market."""\n'
    '    _set_nice(19)  # P4: lowest OS scheduling priority\n'
    '    _resource_guard.yield_p4("agent_state_save")\n'
    '    state = {a.name: a.save_state(prices) for a in agents}\n'
    '    state["__meta__"] = {"preds": meta.preds, "profitable_symbols": meta.profitable_symbols}'
)
NEW_SAVE_ALL = (
    'def save_all_agent_state(agents: list, meta: MetaLearner, prices: dict = None,\n'
    '                         meta_agent=None):\n'
    '    """Save agent state. Pass prices so open positions are marked to market."""\n'
    '    _set_nice(19)  # P4: lowest OS scheduling priority\n'
    '    _resource_guard.yield_p4("agent_state_save")\n'
    '    state = {a.name: a.save_state(prices) for a in agents}\n'
    '    if meta_agent:\n'
    '        state["MetaLearner"] = meta_agent.save_state(prices)\n'
    '    state["__meta__"] = {"preds": meta.preds, "profitable_symbols": meta.profitable_symbols}'
)
if NEW_SAVE_ALL in src:
    print("ADD 5: skipped — save_all_agent_state already handles meta_agent")
elif OLD_SAVE_ALL not in src:
    raise SystemExit("ADD5: save_all_agent_state not found")
else:
    src = src.replace(OLD_SAVE_ALL, NEW_SAVE_ALL)
    changes += 1
    print("ADD 5: meta_agent state saved in save_all_agent_state")

# ── ADD 6: Update all save_all_agent_state call sites in TradingSystem ────────
# Pass self.meta_agent to every matching save call. With ADD 5's signature,
# meta_agent defaults to None and is then simply not saved, so a call site
# that fails to match here silently drops the MetaLearner agent's state.
# Count and report the step only when something actually matched.
OLD_SAVE_CALL = 'save_all_agent_state(self.agents, self.meta, self.dm.prices)'
NEW_SAVE_CALL = 'save_all_agent_state(self.agents, self.meta, self.dm.prices, self.meta_agent)'
_save_hits = src.count(OLD_SAVE_CALL)
if _save_hits:
    src = src.replace(OLD_SAVE_CALL, NEW_SAVE_CALL)
    changes += 1
    print("ADD 6: All save calls pass meta_agent")
elif NEW_SAVE_CALL in src:
    print("ADD 6: skipped — save calls already pass meta_agent")
else:
    print("ADD 6: WARNING — no save_all_agent_state(...) call site matched; "
          "meta_agent state will not be saved")

# ── ADD 7: Update load_all_agent_state call site in initialize() ──────────────
# Pass self.meta_agent so the MetaLearner agent's saved state is restored.
# With ADD 4's signature, meta_agent defaults to None and is then not restored.
# Count and report the step only when a call site actually matched.
OLD_LOAD_CALL = 'load_all_agent_state(self.agents, self.meta)'
NEW_LOAD_CALL = 'load_all_agent_state(self.agents, self.meta, self.meta_agent)'
_load_hits = src.count(OLD_LOAD_CALL)
if _load_hits:
    src = src.replace(OLD_LOAD_CALL, NEW_LOAD_CALL)
    changes += 1
    print("ADD 7: load call passes meta_agent")
elif NEW_LOAD_CALL in src:
    print("ADD 7: skipped — load call already passes meta_agent")
else:
    print("ADD 7: WARNING — no load_all_agent_state(...) call site matched; "
          "meta_agent state will not be restored")

# ── ADD 8: meta_agent trades in _trading_pass (after the 32-agent loop) ───────
# Inject a trading pass for meta_agent immediately before the TRADING_PASS debug
# log in _trading_pass. That debug call exists in one of two forms, with or
# without the impact_tracked field. Both are now rewritten to the single
# NEW_PASS_END below (which also drops impact_tracked). Previously the second
# form used an inline copy of the injected code that had drifted from
# NEW_PASS_END.
#
# Assumes _trading_pass has `prices`, `syms`, `feat_matrix`, `vecs` and
# `atr_idx` in scope at that point. TODO confirm against StockTrading.py.
# NOTE(review): the injected code unpacks model.predict_batch(feat_matrix) into
# (directions, confidences) and expects model.predict(vec) to return a 3-tuple.
# Here `model` is the MetaLearner (ADD 2 passes self.meta as the agent's model).
# _NeutralMeta.predict_batch returns a single 1-D array, so confirm the real
# MetaLearner signatures. A mismatch raises inside the try and is only logged
# as META_AGENT_TRADE_ERR on every pass.
OLD_PASS_END = (
    '        debug_logger.debug(\n'
    '            f"TRADING_PASS | syms={len(syms)} | "\n'
    '            f"cycle={self.cycle} | "\n'
    '            f"impact_tracked={len(self._system_buys)}"\n'
    '        )'
)
OLD_PASS_END2 = (
    '        debug_logger.debug(\n'
    '            f"TRADING_PASS | syms={len(syms)} | "\n'
    '            f"cycle={self.cycle}"\n'
    '        )'
)
NEW_PASS_END = (
    '        # ── MetaLearner (33rd agent) trading pass ──────────────────────────────\n'
    '        if self.meta_agent.model.trained:\n'
    '            self.meta_agent.exit_checks(prices)\n'
    '            self.meta_agent.check_reset(prices)\n'
    '            try:\n'
    '                meta_dirs, meta_cfs = self.meta_agent.model.predict_batch(feat_matrix)\n'
    '                for i, sym in enumerate(syms):\n'
    '                    price = prices.get(sym, 0.0)\n'
    '                    if price <= 0:\n'
    '                        continue\n'
    '                    md  = int(meta_dirs[i])\n'
    '                    mcf = float(meta_cfs[i])\n'
    '                    # MetaLearner uses its own output as both signal and confidence\n'
    '                    eff = mcf * 0.7 + 0.5 * 0.3  # neutral meta of meta\n'
    '                    atr_pct = float(vecs[i][atr_idx]) if atr_idx >= 0 else 0.02\n'
    '                    if md == 1 and eff >= self.meta_agent.min_conf:\n'
    '                        if sym in self.meta_agent.short_positions:\n'
    '                            self.meta_agent.cover(sym, price, reason="signal")\n'
    '                        elif sym not in self.meta_agent.positions and self.allow_long:\n'
    '                            _, pred_pct, _ = self.meta_agent.model.predict(vecs[i])\n'
    '                            self.meta_agent.buy(sym, price, pred_pct or 0.0,\n'
    '                                                mcf, 0.5, eff, atr_pct)\n'
    '                    elif md == 0 and eff >= self.meta_agent.min_conf:\n'
    '                        if sym in self.meta_agent.positions:\n'
    '                            self.meta_agent.sell(sym, price, reason="signal")\n'
    '                        elif sym not in self.meta_agent.short_positions and self.allow_short:\n'
    '                            _, pred_pct, _ = self.meta_agent.model.predict(vecs[i])\n'
    '                            self.meta_agent.short(sym, price, pred_pct or 0.0,\n'
    '                                                  mcf, 0.5, eff)\n'
    '            except Exception as _mpe:\n'
    '                debug_logger.warning(f"META_AGENT_TRADE_ERR | {_mpe}")\n'
    '\n'
    '        debug_logger.debug(\n'
    '            f"TRADING_PASS | syms={len(syms)} | "\n'
    '            f"cycle={self.cycle}"\n'
    '        )'
)
# Marker shared by this version and by earlier runs of this script.
META_PASS_TAG = '# ── MetaLearner (33rd agent) trading pass'

if META_PASS_TAG in src:
    # NEW_PASS_END ends with OLD_PASS_END2, so patching an already-patched file
    # would inject a second copy of the meta pass.
    print("ADD 8: skipped — MetaLearner trading pass already present")
else:
    if 'impact_tracked=' in src:
        _pass_anchor = OLD_PASS_END
        _pass_missing = "ADD8: trading pass end with impact_tracked not found"
    else:
        _pass_anchor = OLD_PASS_END2
        _pass_missing = "ADD8: trading pass end not found in either form"
    if _pass_anchor not in src:
        # Explicit check instead of `assert`, which `python -O` would strip.
        raise SystemExit(_pass_missing)
    src = src.replace(_pass_anchor, NEW_PASS_END)
    changes += 1
    print("ADD 8: MetaLearner trades in _trading_pass as 33rd agent")

# ── ADD 9: widen the human_report leaderboard by one slot for MetaLearner ─────
# The leaderboard ranks whatever `agents` list the report receives. This step
# only copies that list and shows 16 rows instead of 15.
OLD_RANKED = '\n'.join([
    '    ranked = sorted(agents, key=lambda a: a.total_pnl, reverse=True)',
    '    for ag in ranked[:15]:',
])
NEW_RANKED = '\n'.join([
    '    # Include MetaLearner trading agent in leaderboard (passed via agents list or kwarg)',
    '    _all_agents = list(agents)',
    '    ranked = sorted(_all_agents, key=lambda a: a.total_pnl, reverse=True)',
    '    for ag in ranked[:16]:  # +1 for MetaLearner',
])

# OLD_RANKED != NEW_RANKED, so the text changes exactly when the anchor exists.
_ranked_patched = src.replace(OLD_RANKED, NEW_RANKED)
if _ranked_patched != src:
    src = _ranked_patched
    changes += 1
    print("ADD 9: leaderboard shows 16 entries (includes MetaLearner slot)")

# ── ADD 10: Pass meta_agent into _write_human_report call sites ───────────────
# The report functions take `agents: list`. Rather than changing their
# signatures, meta_agent is appended to the list at each call site.
# Count and report the step only when at least one call site matched.
# Previously it was counted unconditionally.
_REPORT_CALL_PATCHES = (
    ('_write_human_report(self.agents, self.dm, self.meta)',
     '_write_human_report(self.agents + [self.meta_agent], self.dm, self.meta)'),
    ('_write_claude_handoff(self.agents, self.dm, self.meta)',
     '_write_claude_handoff(self.agents + [self.meta_agent], self.dm, self.meta)'),
)
_report_hits = 0
for _old_call, _new_call in _REPORT_CALL_PATCHES:
    _report_hits += src.count(_old_call)
    src = src.replace(_old_call, _new_call)
if _report_hits:
    changes += 1
    print("ADD 10: meta_agent included in reports")
elif '[self.meta_agent], self.dm, self.meta)' in src:
    print("ADD 10: skipped — report calls already include meta_agent")
else:
    print("ADD 10: WARNING — no _write_human_report/_write_claude_handoff "
          "call site matched")

# ── ADD 11: meta_agent startup log ───────────────────────────────────────────
# Add a startup log line right after the "Total agents" line.
# OLD_LOG_TOTAL is a prefix of NEW_LOG_TOTAL, so without the "already applied"
# check a rerun would add the MetaLearner line a second time.
OLD_LOG_TOTAL = '        logger.info(f"Total agents : {NUM_AGENTS}")'
NEW_LOG_TOTAL = (
    '        logger.info(f"Total agents : {NUM_AGENTS}")\n'
    '        logger.info(f"MetaLearner  : trading as 33rd agent ($1,000 capital)")'
)
if NEW_LOG_TOTAL in src:
    print("ADD 11: skipped — MetaLearner startup log line already present")
elif OLD_LOG_TOTAL in src:
    src = src.replace(OLD_LOG_TOTAL, NEW_LOG_TOTAL)
    changes += 1
    print("ADD 11: MetaLearner startup log line")

# ── Verify, then write ────────────────────────────────────────────────────────
# The patched source is syntax-checked in memory BEFORE the file on disk is
# touched. The old flow wrote first and on failure "restored" from
# StockTrading.py.bak, a file this script never creates, so a bad patch could
# leave a broken StockTrading.py behind. It also relied on the deprecated,
# race-prone tempfile.mktemp().
print(f"\n{changes} total changes.")

try:
    # compile() raises SyntaxError (including IndentationError) for invalid
    # code, and ValueError on older Pythons if the source contains null bytes.
    compile(src, 'StockTrading.py', 'exec')
except (SyntaxError, ValueError) as e:
    print(f"SYNTAX ERROR: {e}")
    print("StockTrading.py NOT modified")
    sys.exit(1)

# Write a sibling temp file, then swap it in with os.replace (atomic on POSIX),
# so an interrupted write cannot leave a truncated StockTrading.py.
_tmp_path = 'StockTrading.py.tmp'
with open(_tmp_path, 'w', encoding='utf-8') as _fh:
    _fh.write(src)
os.replace(_tmp_path, 'StockTrading.py')
print(f"SYNTAX PASS — {len(src.splitlines())} lines")
