Skip to content

Commit

Permalink
add support for a daskexecutor to run against systems operators (#376)
Browse files Browse the repository at this point in the history
  • Loading branch information
jperez999 committed Jul 1, 2023
1 parent 01af5d3 commit 56b3adc
Showing 1 changed file with 20 additions and 0 deletions.
20 changes: 20 additions & 0 deletions tests/unit/systems/ops/tf/test_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,3 +223,23 @@ def test_workflow_tf_python_wrapper(tmpdir, dataset, engine, python):
str(tmpdir), request_schema, df, ["predictions"], ensemble_config.name
)
assert len(response["predictions"]) == df.shape[0]


@pytest.mark.skipif(not TRITON_SERVER_PATH, reason="triton server not found")
@pytest.mark.parametrize("engine", ["parquet"])
@pytest.mark.parametrize("python", [False, True])
def test_workflow_tf_python_nvt_chain(tmpdir, dataset, engine, python):
# Create a Workflow
workflow_ops = ["name-cat", "name-string"] >> wf_ops.Categorify(cat_cache="host")
workflow = Workflow(workflow_ops)
workflow.fit(dataset)

embedding_shapes = wf_ops.get_embedding_sizes(workflow)

model = create_tf_model(["name-cat", "name-string"], [], embedding_shapes)

df = dataset.to_ddf().compute()[["name-string", "name-cat"]]
response = Workflow(workflow_ops >> PredictTensorflow(model)).fit_transform(dataset)
response = response.to_ddf().compute().reset_index(drop=True)
assert "predictions" in response.columns
assert response.shape[0] == df.shape[0]

0 comments on commit 56b3adc

Please sign in to comment.