diff --git a/tests/unit/systems/ops/tf/test_ensemble.py b/tests/unit/systems/ops/tf/test_ensemble.py index c6c83dca3..cccb02998 100644 --- a/tests/unit/systems/ops/tf/test_ensemble.py +++ b/tests/unit/systems/ops/tf/test_ensemble.py @@ -223,3 +223,23 @@ def test_workflow_tf_python_wrapper(tmpdir, dataset, engine, python): str(tmpdir), request_schema, df, ["predictions"], ensemble_config.name ) assert len(response["predictions"]) == df.shape[0] + + +@pytest.mark.skipif(not TRITON_SERVER_PATH, reason="triton server not found") +@pytest.mark.parametrize("engine", ["parquet"]) +@pytest.mark.parametrize("python", [False, True]) +def test_workflow_tf_python_nvt_chain(tmpdir, dataset, engine, python): + # Create a Workflow + workflow_ops = ["name-cat", "name-string"] >> wf_ops.Categorify(cat_cache="host") + workflow = Workflow(workflow_ops) + workflow.fit(dataset) + + embedding_shapes = wf_ops.get_embedding_sizes(workflow) + + model = create_tf_model(["name-cat", "name-string"], [], embedding_shapes) + + df = dataset.to_ddf().compute()[["name-string", "name-cat"]] + response = Workflow(workflow_ops >> PredictTensorflow(model)).fit_transform(dataset) + response = response.to_ddf().compute().reset_index(drop=True) + assert "predictions" in response.columns + assert response.shape[0] == df.shape[0]