diff --git a/KERNEL_REV b/KERNEL_REV index cb1dbc3a4..af059324d 100644 --- a/KERNEL_REV +++ b/KERNEL_REV @@ -1 +1 @@ -b4d88220cdfad8dba1cfa89892269342ae26feeb +101aa465e71991eec98102bba77aad2f7ad8faed diff --git a/src/databricks/sql/backend/kernel/client.py b/src/databricks/sql/backend/kernel/client.py index db5e025ee..7cdd484fa 100644 --- a/src/databricks/sql/backend/kernel/client.py +++ b/src/databricks/sql/backend/kernel/client.py @@ -63,6 +63,14 @@ logger = logging.getLogger(__name__) +# Headers the kernel manages itself and that the connector must NOT +# forward via ``http_headers`` (lower-cased for case-insensitive match): +# ``authorization`` (the kernel applies the auth provider's token) and +# ``x-databricks-org-id`` (the kernel re-derives it from the ``?o=`` in +# http_path). Forwarding either is redundant and trips the kernel's +# per-request skip-and-warn. +_KERNEL_MANAGED_HEADERS = frozenset({"authorization", "x-databricks-org-id"}) + # ─── Client ───────────────────────────────────────────────────────────────── @@ -91,13 +99,19 @@ def __init__( ): # ``ssl_options`` is translated to the kernel's ``tls_*`` # Session kwargs in ``open_session`` (custom CA, verify - # toggles, mTLS client cert/key). ``http_headers`` / - # ``http_client`` / ``port`` are still accept-and-ignore — the - # kernel manages its own HTTP stack. + # toggles, mTLS client cert/key). ``http_headers`` is forwarded + # to the kernel as custom request headers (it carries the + # connector's composed ``User-Agent`` + any caller headers + the + # SPOG ``x-databricks-org-id``). ``http_client`` / ``port`` are + # still accept-and-ignore — the kernel manages its own HTTP + # stack. self._server_hostname = server_hostname self._http_path = http_path self._auth_provider = auth_provider self._ssl_options = ssl_options + # Caller / connector HTTP headers (list of (name, value) pairs). + # Forwarded to the kernel Session in ``open_session``. + self._http_headers = http_headers or [] # Raw auth-relevant connect() kwargs (auth_type, # oauth_client_id/secret, redirect port, credentials_provider). # The kernel auth bridge needs these to build OAuth kwargs — the @@ -175,19 +189,45 @@ def open_session( session_conf: Optional[Dict[str, str]] = None if session_configuration: session_conf = {k: str(v) for k, v in session_configuration.items()} - # Build auth kwargs here (not in ``__init__``) so the bearer - # token has the shortest possible in-process lifetime: a - # local kwargs dict is GC-eligible the moment this method - # returns, regardless of whether the kernel ``Session()`` - # call succeeded or raised. - auth_kwargs = kernel_auth_kwargs(self._auth_provider, self._auth_options) - # Translate the connector's SSLOptions into the kernel's - # ``tls_*`` Session kwargs. Empty when TLS is left at defaults. - tls_kwargs = _kernel_tls_kwargs(self._ssl_options) - # Translate the connector's ``_retry_*`` kwargs into the kernel's - # ``retry_*`` Session kwargs. Empty when retry is left at defaults. - retry_kwargs = _kernel_retry_kwargs(self._retry_options) + # The kwarg builds run INSIDE the try so the ``finally`` scrub + # below always fires — including when ``kernel_auth_kwargs`` + # itself raises mid-build (e.g. an OAuth token-exchange failure + # while the M2M secret is in hand). Pre-declared empty so the + # ``finally`` can reference them unconditionally even on an early + # raise. Building here (not in ``__init__``) keeps the bearer + # token's in-process lifetime as short as possible. + auth_kwargs: Dict[str, Any] = {} + tls_kwargs: Dict[str, Any] = {} try: + auth_kwargs = kernel_auth_kwargs(self._auth_provider, self._auth_options) + # Translate the connector's SSLOptions into the kernel's + # ``tls_*`` Session kwargs. Empty when TLS is at defaults. + tls_kwargs = _kernel_tls_kwargs(self._ssl_options) + # Translate the connector's ``_retry_*`` kwargs into the + # kernel's ``retry_*`` kwargs. Empty when at defaults. + retry_kwargs = _kernel_retry_kwargs(self._retry_options) + # Forward caller / connector HTTP headers. The kernel applies + # them on every request; a caller ``User-Agent`` is appended + # to the kernel's base UA. Only pass the kwarg when there's + # something to send. + # + # We drop ``Authorization`` and ``x-databricks-org-id`` here, + # before they reach the kernel, for two reasons: (1) the + # kernel manages both itself (auth from the provider; org-id + # re-derived from the ``?o=`` in http_path), so forwarding + # them is redundant; (2) the kernel skips-and-warns those two + # names on every request, so forwarding the SPOG org-id the + # connector always injects would spam a warning per request. + # This double-walls the kernel's own reserved-name skip. + http_headers_kwargs: Dict[str, Any] = {} + if self._http_headers: + forwarded = [ + (str(k), str(v)) + for k, v in self._http_headers + if str(k).lower() not in _KERNEL_MANAGED_HEADERS + ] + if forwarded: + http_headers_kwargs["http_headers"] = forwarded self._kernel_session = _kernel.Session( host=self._server_hostname, http_path=self._http_path, @@ -208,6 +248,7 @@ def open_session( **auth_kwargs, **tls_kwargs, **retry_kwargs, + **http_headers_kwargs, ) except Exception as exc: raise _wrap_kernel_exception("open_session", exc) from exc @@ -304,10 +345,6 @@ def execute_command( ) -> Union["ResultSet", None]: if self._kernel_session is None: raise InterfaceError("Cannot execute_command without an open session.") - if query_tags: - raise NotSupportedError( - "Statement-level query_tags are not yet supported on the kernel backend." - ) try: stmt = self._kernel_session.statement() @@ -321,6 +358,13 @@ def execute_command( try: try: stmt.set_sql(operation) + if query_tags: + # Per-statement query tags. The kernel serialises the + # dict (None value -> bare key) into the SEA + # `query_tags` statement conf. ``query_tags`` is + # already ``Dict[str, Optional[str]]`` from the + # connector, which the kernel accepts directly. + stmt.set_query_tags(query_tags) if parameters: bind_tspark_params(stmt, parameters) if async_op: diff --git a/tests/e2e/test_kernel_backend.py b/tests/e2e/test_kernel_backend.py index ff1a26c8f..1e61bd7b8 100644 --- a/tests/e2e/test_kernel_backend.py +++ b/tests/e2e/test_kernel_backend.py @@ -332,3 +332,32 @@ def test_parameterized_query_decimal(conn): rows = cur.fetchall() # Server echoes back as decimal.Decimal. assert str(rows[0][0]) == "-123.45" + + +def test_query_tags_round_trip(kernel_conn_params): + """Per-statement query_tags are forwarded to the kernel and accepted + by the server. Smoke-level: a malformed query_tags conf would fail + the execute. (query.history ingestion lag makes a sync tag-readback + assertion infeasible.)""" + with sql.connect(**kernel_conn_params) as c: + with c.cursor() as cur: + cur.execute( + "SELECT 1 AS n", + query_tags={"team": "platform", "production": None}, + ) + assert cur.fetchall()[0][0] == 1 + + +def test_user_agent_entry_and_http_headers_round_trip(kernel_conn_params): + """A connection with user_agent_entry (folded into the connector's + User-Agent, then appended to the kernel base UA) and a custom HTTP + header opens and queries cleanly. Replacing the kernel base UA would + break the SEA result disposition (HTTP 400); appending preserves it + — this exercises that end-to-end.""" + params = dict(kernel_conn_params) + params["user_agent_entry"] = "kernel-e2e-app" + params["http_headers"] = [("X-Kernel-E2E", "yes")] + with sql.connect(**params) as c: + with c.cursor() as cur: + cur.execute("SELECT 1 AS n") + assert cur.fetchall()[0][0] == 1 diff --git a/tests/unit/test_kernel_client.py b/tests/unit/test_kernel_client.py index 1a1773b90..8cff9b3d4 100644 --- a/tests/unit/test_kernel_client.py +++ b/tests/unit/test_kernel_client.py @@ -332,26 +332,43 @@ def test_execute_command_forwards_parameters_to_bind_param(): assert stmt.execute.called -def test_execute_command_rejects_query_tags(): +def test_execute_command_forwards_query_tags(): + """Statement-level query_tags are forwarded to the kernel statement + via set_query_tags (the kernel serialises them into the SEA + query_tags conf). Previously rejected with NotSupportedError; now + wired (kernel PR adding Statement.set_query_tags).""" c = _make_client() c._kernel_session = MagicMock() cursor = MagicMock() cursor.arraysize = 100 cursor.buffer_size_bytes = 1024 - with pytest.raises(NotSupportedError, match="query_tags"): - c.execute_command( - operation="SELECT 1", - session_id=MagicMock(), - max_rows=1, - max_bytes=1, - lz4_compression=False, - cursor=cursor, - use_cloud_fetch=False, - parameters=[], - async_op=False, - enforce_embedded_schema_correctness=False, - query_tags={"team": "x"}, - ) + + stmt = MagicMock() + stmt.set_sql = MagicMock() + stmt.set_query_tags = MagicMock() + stmt.execute.return_value = MagicMock( + statement_id="stmt-id", + arrow_schema=MagicMock(return_value=pa.schema([("x", pa.int64())])), + ) + c._kernel_session.statement.return_value = stmt + + tags = {"team": "platform", "production": None} + c.execute_command( + operation="SELECT 1", + session_id=MagicMock(), + max_rows=1, + max_bytes=1, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + query_tags=tags, + ) + + stmt.set_query_tags.assert_called_once_with(tags) + assert stmt.execute.called def test_get_columns_accepts_none_catalog(): @@ -1015,3 +1032,75 @@ def test_retry_delay_default_has_no_mapping(self): # recognised key here — it has no kernel equivalent. out = kernel_client._kernel_retry_kwargs({"retry_delay_default": 5.0}) assert out == {} + + +class TestKernelHttpHeadersForwarding: + """http_headers (the connector's caller headers + composed + User-Agent + SPOG org-id) are forwarded to the kernel Session as the + ``http_headers`` kwarg. The kernel applies them per request (its own + Authorization / org-id win; a caller User-Agent is appended to the + kernel base UA).""" + + def _open_capturing(self, monkeypatch, http_headers): + captured = {} + + def fake_session(**kw): + captured.update(kw) + sess = MagicMock() + sess.session_id = "sess-id" + return sess + + monkeypatch.setattr(kernel_client._kernel, "Session", fake_session) + c = kernel_client.KernelDatabricksClient( + server_hostname="example.cloud.databricks.com", + http_path="/sql/1.0/warehouses/abc", + auth_provider=AccessTokenAuthProvider("dapi-test"), + ssl_options=None, + http_headers=http_headers, + ) + c.open_session(session_configuration=None, catalog=None, schema=None) + return captured + + def test_http_headers_forwarded_to_kernel_session(self, monkeypatch): + headers = [ + ("User-Agent", "PyDatabricksSqlConnector/4.0 (myentry)"), + ("X-Custom", "v1"), + ] + captured = self._open_capturing(monkeypatch, headers) + assert captured.get("http_headers") == [ + ("User-Agent", "PyDatabricksSqlConnector/4.0 (myentry)"), + ("X-Custom", "v1"), + ] + + def test_no_http_headers_omits_kwarg(self, monkeypatch): + # Empty/none headers → the kwarg isn't passed at all (kernel + # keeps its defaults). + captured = self._open_capturing(monkeypatch, []) + assert "http_headers" not in captured + + def test_authorization_and_org_id_dropped_before_forwarding(self, monkeypatch): + # The connector must NOT forward Authorization / x-databricks-org-id + # to the kernel — the kernel manages both (and warns per request + # if it sees them). Double-walls the kernel's own skip. + headers = [ + ("Authorization", "Bearer should-not-forward"), + ("X-Databricks-Org-Id", "12345"), + ("User-Agent", "PyDatabricksSqlConnector/4.0 (e)"), + ("X-Keep", "yes"), + ] + captured = self._open_capturing(monkeypatch, headers) + fwd = captured.get("http_headers") + names = {n.lower() for n, _ in fwd} + assert "authorization" not in names + assert "x-databricks-org-id" not in names + # Non-reserved headers (incl. User-Agent) still forwarded. + assert ("User-Agent", "PyDatabricksSqlConnector/4.0 (e)") in fwd + assert ("X-Keep", "yes") in fwd + + def test_only_reserved_headers_omits_kwarg(self, monkeypatch): + # If the only headers are reserved ones, nothing is forwarded + # and the kwarg is omitted entirely. + captured = self._open_capturing( + monkeypatch, [("Authorization", "Bearer x"), ("x-databricks-org-id", "1")] + ) + assert "http_headers" not in captured diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index c7d8878b5..27d2b96c7 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -476,3 +476,52 @@ def test_retry_kwargs_threaded_into_kernel_client(self): assert opts["retry_stop_after_attempts_duration"] == 600.0 finally: conn.close() + + +class TestKernelUserAgentForwarding: + """user_agent_entry must reach the kernel on the use_kernel path — + session.py folds it into the composed User-Agent and includes it in + all_headers, which is passed to the kernel client as http_headers. + Guards against a regression where session.py stops folding it under + use_kernel=True (which would silently drop partner attribution).""" + + PACKAGE = "databricks.sql" + + def test_user_agent_entry_reaches_kernel_client_http_headers(self): + import sys + import types + + pytest.importorskip( + "pyarrow", reason="kernel client module imports pyarrow at load" + ) + + fake = types.ModuleType("databricks_sql_kernel") + fake.KernelError = type("KernelError", (Exception,), {}) + fake.Session = MagicMock() + + with patch.dict(sys.modules, {"databricks_sql_kernel": fake}), patch( + "databricks.sql.backend.kernel.client.KernelDatabricksClient" + ) as mock_kernel_client, patch( + "%s.session.get_python_sql_connector_auth_provider" % self.PACKAGE + ): + instance = mock_kernel_client.return_value + instance.open_session.return_value = SessionId( + BackendType.SEA, "sess-id", None + ) + + conn = databricks.sql.connect( + server_hostname="foo", + http_path="/sql/1.0/warehouses/abc", + use_kernel=True, + access_token="dapi-xyz", + enable_telemetry=False, + user_agent_entry="my-partner-app", + ) + try: + _, kwargs = mock_kernel_client.call_args + # http_headers carries a User-Agent that embeds the entry. + headers = dict(kwargs["http_headers"]) + ua = headers.get("User-Agent", "") + assert "my-partner-app" in ua, f"UA was {ua!r}" + finally: + conn.close()