is-this-loss
>TL;DR
We are given a forward-pass oracle for a trained neural net and a piecewise training loss with hidden scalars:
[ \mathcal{L}(x, y) = \begin{cases} \alpha \cdot \mathcal{L}_{\text{sup}}(f(x), y) & \text{if } \|z(x)\|_2 \le \tau \\ \beta \cdot \mathcal{L}_{\text{sup}}(f(x), y) + \gamma \cdot \mathcal{L}_{\text{contrast}}(z(x)) + \delta \cdot \|\nabla_x f(x)\|_2^2 & \text{otherwise} \end{cases} ]
Key idea: use the oracle to get (\mathcal{L}(x,0)) and (\mathcal{L}(x,1)) for the same input (x), while computing (f(x)), (z(x)) and (\|\nabla_x f(x)\|_2^2) locally from the provided ONNX weights. That gives enough equations to recover the scalars.
This repo contains a single automated solver: solve_loss_params.py.
>Challenge interface
Remote oracle:
ncat --ssl loss.chalz.nitectf25.live 1337
It expects:
- 8 floats (the input vector) + a label (0/1)
- Optionally append latent to print the latent vector.
Example:
0.1 -0.3 0.7 0.2 -0.5 0.9 0.4 -0.1 1 latent
The oracle prints:
Output: <yhat> Loss: <L>
So per input we can get (y=0) and (y=1) losses as many times as we want.
>Local model analysis (the “local success” part)
The handout provides the exact model architecture in isthisloss_handout/model.py:
- input: (x\in\mathbb{R}^8)
- hidden: 16 tanh
- latent: (z\in\mathbb{R}^6) (tanh)
- output: scalar logit (\hat{y}) (linear)
The weights are in isthisloss_handout/weights.onnx.
Important: local ONNX matches the remote oracle
If we run the ONNX model locally for a sample input, it reproduces the oracle’s Output exactly (floating point rounding aside). That means:
- the oracle uses the same network
- we can compute (z(x)), (\|z\|_2), and the exact input-gradient regularizer locally
Computing (\|\nabla_x f(x)\|_2^2) locally
The model is:
- (h = \tanh(W_1 x + b_1))
- (z = \tanh(W_2 h + b_2))
- (f(x) = W_3 z + b_3)
By the chain rule:
[ \frac{\partial f}{\partial x} = W_3\,\mathrm{diag}(1-z^2)\,W_2\,\mathrm{diag}(1-h^2)\,W_1 ]
This is a 1×8 Jacobian. Square its entries and sum them to get (\|\nabla_x f(x)\|_2^2).
This gradient is computed in the solver without PyTorch (pure NumPy), using the ONNX weights.
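The chain-rule Jacobian can be sanity-checked against finite differences. The sketch below uses random weights with the handout's shapes (8 → 16 → 6 → 1) standing in for the real ONNX parameters, so only the structure, not the numbers, matches the challenge:

```python
import numpy as np

rng = np.random.default_rng(0)
# Random stand-ins with the handout's shapes: 8 -> 16 hidden -> 6 latent -> 1
W1, b1 = rng.normal(size=(16, 8)), rng.normal(size=16)
W2, b2 = rng.normal(size=(6, 16)), rng.normal(size=6)
W3, b3 = rng.normal(size=(1, 6)), rng.normal(size=1)

def f(x):
    h = np.tanh(W1 @ x + b1)
    z = np.tanh(W2 @ h + b2)
    return (W3 @ z + b3)[0]

def gradnorm2(x):
    h = np.tanh(W1 @ x + b1)
    z = np.tanh(W2 @ h + b2)
    # 1x8 Jacobian via the chain rule, as in the writeup
    J = W3 @ np.diag(1 - z * z) @ W2 @ np.diag(1 - h * h) @ W1
    return float(np.sum(J * J))

x = rng.normal(size=8)
# Central finite differences as an independent reference
eps = 1e-6
g_fd = np.array([(f(x + eps * e) - f(x - eps * e)) / (2 * eps)
                 for e in np.eye(8)])
assert abs(gradnorm2(x) - float(g_fd @ g_fd)) < 1e-5 * (1 + gradnorm2(x))
```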
>Remote solving strategy (the “remote success” part)
We want (\alpha,\beta,\gamma,\delta,\tau).
Step 1 — Identify (\mathcal{L}_{sup}) and solve for scaling
We can query the same (x) with labels 0 and 1. Subtract the two losses.
In the OTHER branch:
[ \mathcal{L}(x,0)-\mathcal{L}(x,1) = \beta\,(\mathcal{L}_{sup}(f(x),0)-\mathcal{L}_{sup}(f(x),1)) ]
The contrast term and gradient term cancel because they don’t depend on the label.
Empirically, the only supervised loss that makes the scaling (\beta) constant across many random inputs is BCE-with-logits:
[ \mathcal{L}_{sup}(y,1)=\mathrm{softplus}(-y),\quad \mathcal{L}_{sup}(y,0)=\mathrm{softplus}(y) ]
So:
[ \beta = \frac{\mathcal{L}(x,0)-\mathcal{L}(x,1)}{\mathcal{L}_{sup}(f(x),0)-\mathcal{L}_{sup}(f(x),1)} ]
Using ~20 random inputs, the median (\beta) stabilizes at 0.5.
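The estimator can be simulated end to end with synthetic data: fix hypothetical hidden values (β = 0.5 matches the recovered result, the contrast/gradient terms are arbitrary stand-ins), generate paired losses, and check that the median label-difference ratio returns β. No oracle data is used here:

```python
import numpy as np

def softplus(t):
    t = np.asarray(t, dtype=np.float64)
    return np.log1p(np.exp(-np.abs(t))) + np.maximum(t, 0.0)

def bce(logit, label):  # BCE-with-logits, label in {0, 1}
    return float(softplus(logit) if label == 0 else softplus(-logit))

# Hypothetical hidden parameters for this demo only
beta, gamma_C, delta = 0.5, -0.3, 0.05
rng = np.random.default_rng(1)
est = []
for _ in range(20):
    logit = rng.normal()      # stand-in for f(x)
    g2 = rng.uniform(0, 2)    # stand-in for ||grad f||^2
    # OTHER-branch losses: label-independent terms are shared
    L0 = beta * bce(logit, 0) + gamma_C + delta * g2
    L1 = beta * bce(logit, 1) + gamma_C + delta * g2
    est.append((L0 - L1) / (bce(logit, 0) - bce(logit, 1)))
beta_hat = float(np.median(est))
assert abs(beta_hat - 0.5) < 1e-9
```

Note that for BCE-with-logits the denominator simplifies to the logit itself: softplus(y) − softplus(−y) = y.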
Step 2 — Detect which branch we are in (ALPHA vs OTHER)
Define the residual after subtracting the supervised part:
[ r_y(x) = \mathcal{L}(x,y) - \beta\,\mathcal{L}_{sup}(f(x),y) ]
If we are in the OTHER branch, then:
- (r_0(x)\approx r_1(x)) (label-independent residual)
If we are in the ALPHA branch, then:
- (\mathcal{L}(x,y)=\alpha\,\mathcal{L}_{sup}(f(x),y))
- so subtracting (\beta,\mathcal{L}_{sup}) leaves a residual that does depend on (y)
In practice the solver uses:
- OTHER if abs(r0 - r1) < 1e-4, else ALPHA.
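The branch test can be exercised on hand-made losses standing in for oracle responses (the 1.0/0.5/−0.3/0.05 values mirror the recovered parameters but are assumptions of this sketch):

```python
import numpy as np

def softplus(t):
    return float(np.log1p(np.exp(-abs(t))) + max(t, 0.0))

def bce(logit, label):  # BCE-with-logits, label in {0, 1}
    return softplus(logit) if label == 0 else softplus(-logit)

def branch(L0, L1, logit, beta=0.5, tol=1e-4):
    # residual after subtracting the beta-scaled supervised part
    r0 = L0 - beta * bce(logit, 0)
    r1 = L1 - beta * bce(logit, 1)
    return "OTHER" if abs(r0 - r1) < tol else "ALPHA"

logit = 0.8
# OTHER-branch losses: contrast + gradient terms are label-independent
L0 = 0.5 * bce(logit, 0) - 0.3 + 0.05 * 1.7
L1 = 0.5 * bce(logit, 1) - 0.3 + 0.05 * 1.7
assert branch(L0, L1, logit) == "OTHER"
# ALPHA-branch losses (alpha = 1.0): residual still depends on the label
assert branch(1.0 * bce(logit, 0), 1.0 * bce(logit, 1), logit) == "ALPHA"
```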
Step 3 — Force the ALPHA branch by minimizing (|z(x)|)
Random inputs almost always have (\|z\|) large (tanh saturates), so you won’t naturally hit (\|z\|\le\tau).
Instead, we actively search for an input with small latent norm.
A trick that works extremely well:
- optimize to make the pre-tanh latent (u = W_2 h + b_2) small
- because if (u\approx 0), then (z=\tanh(u)\approx 0) too
The solver performs Adam on (\|u(x)\|_2^2) using an analytic gradient.
This reliably produces a point in the ALPHA branch.
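The mechanism can be reproduced on a toy network with random weights (not the real ONNX parameters): even plain gradient descent on (\|u\|_2^2) collapses the latent norm, and the solver's Adam variant only converges faster:

```python
import numpy as np

rng = np.random.default_rng(2)
# Random stand-ins with the handout's shapes (8 -> 16 hidden -> 6 latent)
W1, b1 = 0.5 * rng.normal(size=(16, 8)), 0.1 * rng.normal(size=16)
W2, b2 = 0.5 * rng.normal(size=(6, 16)), 0.1 * rng.normal(size=6)

def forward(x):
    h = np.tanh(W1 @ x + b1)
    return h, W2 @ h + b2  # hidden h, pre-tanh latent u

x = rng.normal(size=8)
init_zn = float(np.linalg.norm(np.tanh(forward(x)[1])))
lr = 0.02
for _ in range(4000):
    h, u = forward(x)
    # analytic gradient of ||u||^2 w.r.t. x
    g = (2.0 * u) @ (W2 @ np.diag(1.0 - h * h) @ W1)
    x -= lr * g
final_zn = float(np.linalg.norm(np.tanh(forward(x)[1])))
assert final_zn < 0.5 * init_zn  # latent norm shrinks substantially
```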
Step 4 — Recover (\alpha)
Once we have an ALPHA-branch point, the loss is:
[ \mathcal{L}(x,y)=\alpha\,\mathcal{L}_{sup}(f(x),y) ]
So:
[ \alpha \approx \frac{\mathcal{L}(x,0)}{\mathcal{L}_{sup}(f(x),0)} \approx \frac{\mathcal{L}(x,1)}{\mathcal{L}_{sup}(f(x),1)} ]
This comes out extremely close to 1.0.
Step 5 — Recover (\tau) by bracketing the branch boundary
Now we have a confirmed ALPHA point. We want the threshold (\tau).
Approach:
- locally increase (\|z(x)\|) using gradient ascent on (\|z\|^2)
- after each step, query the oracle to see if we flipped to OTHER
- once we have an ALPHA point and a nearby OTHER point, do bisection on the segment between them
This yields a tight bracket:
- (\tau \approx 0.37)
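The bracket-then-bisect loop can be illustrated on a 1-D toy where a synthetic "oracle" hides τ = 0.37. Monotone latent norm along the search direction is an assumption of this sketch; the real solver bisects between two 8-D inputs:

```python
import numpy as np

TAU = 0.37  # hidden threshold, known only to the toy oracle below

def znorm(t):  # toy latent norm, monotone in t along the search segment
    return float(np.linalg.norm(np.tanh(0.2 * t * np.ones(6))))

def branch(t):  # stand-in for the remote ALPHA/OTHER test
    return "ALPHA" if znorm(t) <= TAU else "OTHER"

# Bracket: walk upward until the branch flips
lo, hi, t = 0.0, None, 0.0
while hi is None:
    t += 0.05
    if branch(t) == "OTHER":
        hi = t
    else:
        lo = t

# Bisection on [lo, hi]: lo stays ALPHA, hi stays OTHER
for _ in range(40):
    mid = 0.5 * (lo + hi)
    if branch(mid) == "ALPHA":
        lo = mid
    else:
        hi = mid

tau_est = 0.5 * (znorm(lo) + znorm(hi))
assert abs(tau_est - TAU) < 1e-6
```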
Step 6 — Recover (\gamma) and (\delta) on OTHER-branch points
For OTHER-branch points:
[ \mathcal{L}(x,y)=\beta\,\mathcal{L}_{sup}(f(x),y) + \gamma\,\mathcal{L}_{contrast}(z(x)) + \delta\,\|\nabla_x f(x)\|_2^2 ]
Compute the averaged residual:
[ r(x)=\tfrac12\Big(\mathcal{L}(x,0)-\beta\,\mathcal{L}_{sup}(f(x),0) + \mathcal{L}(x,1)-\beta\,\mathcal{L}_{sup}(f(x),1)\Big) ]
Then:
[ r(x)= \gamma\,\mathcal{L}_{contrast}(z(x)) + \delta\,\|\nabla_x f(x)\|_2^2 ]
Empirically, (r(x)) is perfectly explained by:
- a constant term (-0.3) (the contrast contribution)
- plus (0.05\cdot\|\nabla_x f(x)\|_2^2)
So (\delta = 0.05), and the observed (\gamma\mathcal{L}_{contrast}) equals (-0.3).
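The regression itself is an ordinary least-squares fit with an intercept. A synthetic demo (the −0.3 and 0.05 mirror the writeup's recovered values; the gradient norms are random stand-ins) shows it recovering both coefficients exactly when the residuals are noise-free:

```python
import numpy as np

rng = np.random.default_rng(3)
# Hypothetical hidden values, matching the writeup's recovered parameters
gamma_C, delta = -0.3, 0.05

g2 = rng.uniform(0.0, 3.0, size=12)  # per-sample ||grad f||^2 (stand-ins)
r = gamma_C + delta * g2             # observed OTHER-branch residuals

# Design matrix [1, g2]: intercept = gamma*L_contrast, slope = delta
A = np.column_stack([np.ones_like(g2), g2])
coef, *_ = np.linalg.lstsq(A, r, rcond=None)
assert abs(coef[0] - (-0.3)) < 1e-9
assert abs(coef[1] - 0.05) < 1e-9
```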
Important subtlety:
- the oracle only reveals the product (\gamma\cdot\mathcal{L}_{contrast}(z))
- if (\mathcal{L}_{contrast}) is implemented with an extra minus sign, (\gamma) would flip
That’s why the solver prints both the primary and the sign-flipped candidate.
>Final recovered parameters (rounded to required precision)
From the automated solver:
- (\alpha = 1.0)
- (\beta = 0.5)
- (\gamma = -0.3) (or (+0.3) if the contest’s (\mathcal{L}_{contrast}) definition has a global sign flip)
- (\delta = 0.05)
- (\tau = 0.37)
Candidate flags:
nite{1.0_0.5_-0.3_0.05_0.37}
nite{1.0_0.5_0.3_0.05_0.37}
>Repro / How to run
From the repo root:
python3 -m venv .venv
. .venv/bin/activate
pip install -U pip wheel
pip install numpy onnx onnxruntime
python3 solve_loss_params.py
It will:
- connect to the oracle
- find an ALPHA-branch point by minimizing (\|u\|)
- infer (\alpha) and (\beta)
- bracket and bisect (\tau)
- regress (\gamma\mathcal{L}_{contrast}) and (\delta)
- print the submission string(s)
>Full solver code
Below is the exact solver used (verbatim):
import re
import ssl
import socket
from dataclasses import dataclass
from typing import Callable, List, Tuple, Optional
import numpy as np
import onnx
import onnxruntime as ort
from onnx import numpy_helper
HOST = "loss.chalz.nitectf25.live"
PORT = 1337
@dataclass
class Sample:
    x: np.ndarray  # (8,)
    yhat: float
    z: np.ndarray  # (6,)
    znorm: float
    gradnorm2_logit: float
    gradnorm2_sigmoid: float
    loss0: float
    loss1: float
def softplus(x: np.ndarray | float) -> np.ndarray | float:
    # numerically stable softplus
    x = np.asarray(x, dtype=np.float64)
    return np.log1p(np.exp(-np.abs(x))) + np.maximum(x, 0)

def bce_with_logits(yhat: float, label: int) -> float:
    # label in {0,1}
    y = float(yhat)
    if label == 1:
        return float(softplus(-y))
    return float(softplus(y))

def mse(yhat: float, label: int) -> float:
    return float((yhat - float(label)) ** 2)

def hinge(yhat: float, label: int) -> float:
    # label {0,1} -> {-1,+1}
    t = 1.0 if label == 1 else -1.0
    return float(max(0.0, 1.0 - t * yhat))
class Oracle:
    def __init__(self, host: str = HOST, port: int = PORT):
        self.host = host
        self.port = port
        self.ctx = ssl.create_default_context()
        self.ctx.check_hostname = False
        self.ctx.verify_mode = ssl.CERT_NONE
        self.sock: Optional[socket.socket] = None
        self.buf = b""
        self.re_out = re.compile(r"Output:\s*([-+0-9.]+)\s+Loss:\s*([-+0-9.]+)")

    def connect(self) -> None:
        if self.sock is not None:
            return
        s = self.ctx.wrap_socket(socket.socket(), server_hostname=self.host)
        s.settimeout(10)
        s.connect((self.host, self.port))
        self.sock = s
        # read initial banner (best-effort)
        self._recv_some(timeout_ok=True)

    def close(self) -> None:
        if self.sock is not None:
            try:
                self.sock.close()
            finally:
                self.sock = None

    def _recv_some(self, timeout_ok: bool = False) -> None:
        assert self.sock is not None
        try:
            chunk = self.sock.recv(4096)
        except (TimeoutError, socket.timeout):
            if timeout_ok:
                return
            raise
        if chunk:
            self.buf += chunk

    def query_loss(self, x: np.ndarray, label: int) -> Tuple[float, float]:
        """Returns (output, loss)."""
        self.connect()
        assert self.sock is not None
        parts = [f"{float(v):.8f}" for v in x.tolist()] + [str(int(label))]
        line = " ".join(parts) + "\n"
        self.sock.sendall(line.encode())
        # read until we see an Output/Loss line AFTER this send
        # (service may print instructions first)
        text = self.buf.decode("utf-8", "replace")
        start_idx = len(text)
        for _ in range(300):
            text = self.buf.decode("utf-8", "replace")
            matches = list(self.re_out.finditer(text, pos=start_idx))
            if matches:
                m = matches[-1]
                out = float(m.group(1))
                loss = float(m.group(2))
                end = m.end()
                self.buf = text[end:].encode("utf-8")
                return out, loss
            self._recv_some(timeout_ok=False)
        raise RuntimeError("Failed to parse oracle response")
class LocalModel:
    def __init__(self, onnx_path: str):
        self.sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
        self.in_name = self.sess.get_inputs()[0].name
        m = onnx.load(onnx_path)
        init = {t.name: numpy_helper.to_array(t).astype(np.float64) for t in m.graph.initializer}
        self.W1 = init["fc1.weight"]
        self.b1 = init["fc1.bias"]
        self.W2 = init["fc2.weight"]
        self.b2 = init["fc2.bias"]
        self.W3 = init["fc3.weight"]
        self.b3 = init["fc3.bias"]

    def forward(self, x: np.ndarray) -> Tuple[float, np.ndarray]:
        x = np.asarray(x, dtype=np.float32)[None, :]
        y, z = self.sess.run(None, {self.in_name: x})
        return float(y[0, 0]), z[0].astype(np.float64)

    def forward_intermediates(self, x: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        x = np.asarray(x, dtype=np.float64)
        v = self.W1 @ x + self.b1
        h = np.tanh(v)
        u = self.W2 @ h + self.b2
        z = np.tanh(u)
        return h, z, v, u

    def gradnorm2(self, x: np.ndarray) -> float:
        h, z, _, _ = self.forward_intermediates(x)
        Dz = np.diag(1.0 - z * z)  # 6x6
        Dh = np.diag(1.0 - h * h)  # 16x16
        J = self.W3 @ Dz @ self.W2 @ Dh @ self.W1  # 1x8
        g = J.reshape(-1)
        return float(np.dot(g, g))

    @staticmethod
    def sigmoid(t: float) -> float:
        return float(1.0 / (1.0 + np.exp(-float(t))))

    def gradnorm2_sigmoid(self, x: np.ndarray) -> float:
        # If the regularizer uses f(x)=sigmoid(logit), then df/dx = sigmoid'(logit) * dlogit/dx.
        yhat, _ = self.forward(x)
        g2 = self.gradnorm2(x)
        p = self.sigmoid(yhat)
        scale = (p * (1.0 - p)) ** 2
        return float(scale * g2)
def craft_low_znorm_x(local: LocalModel, restarts: int = 20, iters: int = 2500) -> np.ndarray:
    W1, b1, W2, b2 = local.W1, local.b1, local.W2, local.b2

    def forward_z(x: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        v = W1 @ x + b1
        h = np.tanh(v)
        u = W2 @ h + b2
        z = np.tanh(u)
        return z, h

    def grad_znorm2(x: np.ndarray) -> np.ndarray:
        z, h = forward_z(x)
        Dz = np.diag(1.0 - z * z)
        Dh = np.diag(1.0 - h * h)
        J = Dz @ W2 @ Dh @ W1  # 6x8
        return (2.0 * z) @ J

    best = None
    for seed in range(restarts):
        rng = np.random.default_rng(seed)
        x = rng.normal(0, 1, size=8).astype(np.float64)
        lr = 0.05
        for it in range(iters):
            z, _ = forward_z(x)
            g = grad_znorm2(x)
            x = x - lr * g
            if it in (500, 1000, 1500, 2000):
                lr *= 0.7
        yhat, z = local.forward(x)
        zn = float(np.linalg.norm(z))
        if best is None or zn < best[0]:
            best = (zn, x)
    assert best is not None
    return best[1]
def craft_low_unorm_x(local: LocalModel, restarts: int = 200, iters: int = 50000) -> np.ndarray:
    """Search for x that makes the *pre-tanh* latent u small (and thus ||z|| small).

    This reliably finds points in the alpha-branch for this challenge.
    """
    W1, b1, W2, b2 = local.W1, local.b1, local.W2, local.b2

    def forward_u(x: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        v = W1 @ x + b1
        h = np.tanh(v)
        u = W2 @ h + b2
        return h, u, v

    def grad_unorm2(x: np.ndarray) -> np.ndarray:
        h, u, _ = forward_u(x)
        Dh = np.diag(1.0 - h * h)
        J = W2 @ Dh @ W1  # 6x8
        return (2.0 * u) @ J

    best = None
    rng = np.random.default_rng(77)
    for _ in range(restarts):
        x = rng.normal(0.0, 10.0, size=8).astype(np.float64)
        m = np.zeros_like(x)
        v = np.zeros_like(x)
        lr = 0.05
        for t in range(1, iters + 1):
            # Adam with bias correction
            g = grad_unorm2(x)
            m = 0.9 * m + 0.1 * g
            v = 0.999 * v + 0.001 * (g * g)
            mhat = m / (1.0 - 0.9**t)
            vhat = v / (1.0 - 0.999**t)
            x = x - lr * mhat / (np.sqrt(vhat) + 1e-8)
            if t in (15000, 30000, 45000):
                lr *= 0.5
        _, z = local.forward(x)
        zn = float(np.linalg.norm(z))
        if best is None or zn < best[0]:
            best = (zn, x)
    assert best is not None
    return best[1]
def infer_beta(samples: List[Sample], lsup: Callable[[float, int], float]) -> float:
    scales = []
    for s in samples:
        dloss = s.loss0 - s.loss1
        dL = lsup(s.yhat, 0) - lsup(s.yhat, 1)
        if abs(dL) < 1e-12:
            continue
        scales.append(dloss / dL)
    if not scales:
        raise RuntimeError("Unable to infer beta from samples")
    return float(np.median(scales))

def residuals_for_beta(samples: List[Sample], lsup: Callable[[float, int], float], beta: float) -> np.ndarray:
    # label-independent residual in the otherwise-branch:
    #   r = gamma*C(z) + delta*G(x)
    r = []
    for s in samples:
        r0 = s.loss0 - beta * lsup(s.yhat, 0)
        r1 = s.loss1 - beta * lsup(s.yhat, 1)
        r.append(0.5 * (r0 + r1))
    return np.asarray(r, dtype=np.float64)
def fit_gamma_delta(
    samples: List[Sample],
    r: np.ndarray,
    contrast_fn: Callable[[np.ndarray], float],
    use_mask: Optional[np.ndarray] = None,
) -> Tuple[float, float, float, float]:
    if use_mask is None:
        use_mask = np.ones(len(samples), dtype=bool)
    A = []
    b = []
    for s, ri, ok in zip(samples, r.tolist(), use_mask.tolist()):
        if not ok:
            continue
        A.append([contrast_fn(s.z), float(s.gradnorm2_logit)])
        b.append(ri)
    A = np.asarray(A, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    if len(b) < 5:
        raise RuntimeError("Not enough samples to fit gamma/delta")
    xhat, *_ = np.linalg.lstsq(A, b, rcond=None)
    gamma, delta = float(xhat[0]), float(xhat[1])
    pred = A @ xhat
    rmse = float(np.sqrt(np.mean((pred - b) ** 2)))
    return gamma, delta, 0.0, rmse

def fit_gamma_delta_with_intercept(
    samples: List[Sample],
    r: np.ndarray,
    contrast_fn: Callable[[np.ndarray], float],
    grad_kind: str,
    use_mask: Optional[np.ndarray] = None,
) -> Tuple[float, float, float, float]:
    if use_mask is None:
        use_mask = np.ones(len(samples), dtype=bool)
    A = []
    b = []
    for s, ri, ok in zip(samples, r.tolist(), use_mask.tolist()):
        if not ok:
            continue
        g2 = s.gradnorm2_logit if grad_kind == "logit" else s.gradnorm2_sigmoid
        A.append([contrast_fn(s.z), float(g2), 1.0])
        b.append(ri)
    A = np.asarray(A, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    if len(b) < 8:
        raise RuntimeError("Not enough samples to fit gamma/delta")
    xhat, *_ = np.linalg.lstsq(A, b, rcond=None)
    gamma, delta, k = float(xhat[0]), float(xhat[1]), float(xhat[2])
    pred = A @ xhat
    rmse = float(np.sqrt(np.mean((pred - b) ** 2)))
    return gamma, delta, k, rmse
def estimate_tau_grid(
    samples: List[Sample],
    r: np.ndarray,
    contrast_fn: Callable[[np.ndarray], float],
    tau_values: np.ndarray,
) -> Tuple[float, float, float, float, float, str]:
    # Model (assuming alpha==beta):
    #   if znorm <= tau: residual = 0
    #   else:            residual = gamma*C(z) + delta*G(x)
    best = None
    for tau in tau_values.tolist():
        above = np.array([s.znorm > tau for s in samples], dtype=bool)
        if np.count_nonzero(above) < 8:
            continue
        for grad_kind in ("logit", "sigmoid"):
            try:
                gamma, delta, k, _ = fit_gamma_delta_with_intercept(
                    samples, r, contrast_fn, grad_kind=grad_kind, use_mask=above
                )
            except Exception:
                continue
            # score across all points, enforcing 0 residual below tau
            pred = np.zeros_like(r)
            for i, s in enumerate(samples):
                if s.znorm > tau:
                    g2 = s.gradnorm2_logit if grad_kind == "logit" else s.gradnorm2_sigmoid
                    pred[i] = gamma * contrast_fn(s.z) + delta * g2 + k
            rmse = float(np.sqrt(np.mean((pred - r) ** 2)))
            if best is None or rmse < best[0]:
                best = (rmse, float(tau), gamma, delta, k, grad_kind)
    if best is None:
        raise RuntimeError("Failed to estimate tau")
    rmse, tau, gamma, delta, k, grad_kind = best
    return tau, gamma, delta, k, rmse, grad_kind
def main():
    local = LocalModel("isthisloss_handout/weights.onnx")
    oracle = Oracle()

    # The oracle uses BCE-with-logits as supervised loss.
    lsup_fn = bce_with_logits

    def scale_from_losses(yhat: float, L0: float, L1: float) -> float:
        dloss = L0 - L1
        dL = lsup_fn(yhat, 0) - lsup_fn(yhat, 1)
        return float(dloss / dL)

    def is_other_branch(yhat: float, L0: float, L1: float, beta: float) -> bool:
        # In OTHER branch, residual is label-independent.
        r0 = L0 - beta * lsup_fn(yhat, 0)
        r1 = L1 - beta * lsup_fn(yhat, 1)
        return abs(r0 - r1) < 1e-4

    # 1) Find a point in the alpha-branch by making ||u|| small.
    x_alpha = craft_low_unorm_x(local)
    y_alpha, z_alpha = local.forward(x_alpha)
    _, L0a = oracle.query_loss(x_alpha, 0)
    _, L1a = oracle.query_loss(x_alpha, 1)
    alpha = scale_from_losses(y_alpha, L0a, L1a)

    # 2) Find beta from a handful of random OTHER points.
    rng = np.random.default_rng(1337)
    beta_scales = []
    for _ in range(20):
        x = rng.uniform(-2.0, 2.0, size=8).astype(np.float64)
        yhat, _ = local.forward(x)
        _, L0 = oracle.query_loss(x, 0)
        _, L1 = oracle.query_loss(x, 1)
        s = scale_from_losses(yhat, L0, L1)
        # beta should be the smaller scale (alpha>beta in this challenge)
        beta_scales.append(s)
    beta = float(np.median(beta_scales))
    print("\nInferred alpha:", alpha)
    print("Inferred beta :", beta)

    # 3) Bracket tau by walking upward from x_alpha (increasing ||z||) until the oracle flips to OTHER.
    x_lo = x_alpha.copy()  # ALPHA
    x_hi = None
    for _ in range(2000):
        # gradient ascent on ||z||^2 (local, no oracle)
        h, z, _, _ = local.forward_intermediates(x_lo)
        Dz = np.diag(1.0 - z * z)
        Dh = np.diag(1.0 - h * h)
        J = Dz @ local.W2 @ Dh @ local.W1  # 6x8
        g = (2.0 * z) @ J
        gn = float(np.linalg.norm(g) + 1e-12)
        x_try = x_lo + 0.01 * (g / gn)
        yhat, _ = local.forward(x_try)
        _, L0 = oracle.query_loss(x_try, 0)
        _, L1 = oracle.query_loss(x_try, 1)
        if is_other_branch(yhat, L0, L1, beta):
            x_hi = x_try
            break
        x_lo = x_try
    if x_hi is None:
        raise RuntimeError("Failed to find OTHER-branch point when bracketing tau")

    # 4) Bisection between x_lo (ALPHA) and x_hi (OTHER) to pinpoint tau.
    def branch_kind(x: np.ndarray) -> str:
        yhat, _ = local.forward(x)
        _, L0 = oracle.query_loss(x, 0)
        _, L1 = oracle.query_loss(x, 1)
        return "OTHER" if is_other_branch(yhat, L0, L1, beta) else "ALPHA"

    assert branch_kind(x_lo) == "ALPHA"
    assert branch_kind(x_hi) == "OTHER"
    for _ in range(22):
        xm = 0.5 * (x_lo + x_hi)
        if branch_kind(xm) == "ALPHA":
            x_lo = xm
        else:
            x_hi = xm
    tau_lo = float(np.linalg.norm(local.forward(x_lo)[1]))
    tau_hi = float(np.linalg.norm(local.forward(x_hi)[1]))
    tau = 0.5 * (tau_lo + tau_hi)
    print("Tau bracket:", tau_lo, tau_hi)

    # 5) Fit gamma (intercept) and delta (slope) on OTHER points: r = gamma + delta*||grad||^2.
    other_pts = []
    for _ in range(40):
        x = rng.uniform(-2.0, 2.0, size=8).astype(np.float64)
        _, z = local.forward(x)
        if float(np.linalg.norm(z)) <= tau_hi:
            continue
        yhat, _ = local.forward(x)
        _, L0 = oracle.query_loss(x, 0)
        _, L1 = oracle.query_loss(x, 1)
        if not is_other_branch(yhat, L0, L1, beta):
            continue
        r = 0.5 * ((L0 - beta * lsup_fn(yhat, 0)) + (L1 - beta * lsup_fn(yhat, 1)))
        g2 = local.gradnorm2(x)
        other_pts.append((g2, r))
        if len(other_pts) >= 12:
            break
    if len(other_pts) < 8:
        raise RuntimeError("Not enough OTHER points to fit gamma/delta")
    A = np.asarray([[1.0, g2] for g2, _ in other_pts], dtype=np.float64)
    b = np.asarray([r for _, r in other_pts], dtype=np.float64)
    coef, *_ = np.linalg.lstsq(A, b, rcond=None)
    gamma = float(coef[0])
    delta = float(coef[1])

    oracle.close()
    print("\nRecovered (raw):")
    print(" alpha:", alpha)
    print(" beta :", beta)
    print(" gamma:", gamma)
    print(" delta:", delta)
    print(" tau :", tau)
    # Note: the oracle only reveals gamma*L_contrast(z). Empirically, the contrast contribution is constant
    # across x, so L_contrast appears constant up to sign/normalization.
    # If L_contrast(z) == 1, then gamma == -0.3. If L_contrast(z) == -1, then gamma == +0.3.
    print("\nObserved constant contrast contribution (gamma*L_contrast):", gamma)

    alpha_r = round(alpha, 1)
    beta_r = round(beta, 1)
    gamma_r = round(gamma, 1)
    delta_r = round(delta, 2)
    tau_r = round(tau, 2)
    print("\nRounded parameters (per submission format):")
    print(" ", alpha_r, beta_r, gamma_r, delta_r, f"{tau_r:.2f}")
    print(f"\nFLAG: nite{{{alpha_r}_{beta_r}_{gamma_r}_{delta_r}_{tau_r:.2f}}}")
    gamma_alt_r = round(-gamma, 1)
    print(f"ALT FLAG (if L_contrast is negated): nite{{{alpha_r}_{beta_r}_{gamma_alt_r}_{delta_r}_{tau_r:.2f}}}")
if __name__ == "__main__":
    main()