Skip to content

Commit

Permalink
Add progress for TS model loading
Browse files Browse the repository at this point in the history
  • Loading branch information
DanilRostov committed Dec 16, 2023
1 parent 2f67b5e commit 594daa8
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 13 deletions.
14 changes: 11 additions & 3 deletions src/engine/Graphics2d.js
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class Graphics2d extends React.Component {
constructor(props) {
super(props);

this.store = props;
this.m_mount = React.createRef();

this.onMouseDown = this.onMouseDown.bind(this);
Expand Down Expand Up @@ -527,7 +528,7 @@ class Graphics2d extends React.Component {
if (isSegm) {
const w = this.m_toolPick.m_wScreen;
const h = this.m_toolPick.m_hScreen;
this.segm2d.render(ctx, w, h, this.imgData);
this.segm2d.renderImage(ctx, w, h, this.imgData);
} else {
createImageBitmap(this.imgData)
.then((imageBitmap) => {
Expand Down Expand Up @@ -692,6 +693,12 @@ class Graphics2d extends React.Component {
startY: evt.clientY,
});
}

if (this.m_isSegmented && this.segm2d.model) {
// We do not need update segmented image (with model)
// on mouse move event to performance issues.
return;
}
store.graphics2d.forceUpdate();
}

Expand Down Expand Up @@ -767,13 +774,14 @@ class Graphics2d extends React.Component {
* Invoke forced rendering, after some tool visual changes
*/
forceUpdate(volIndex) {
// console.log('forceUpdate ...');
console.log('forceUpdate ...');
this.prepareImageForRender(volIndex);
// this.forceRender();
if (this.m_isSegmented) {
// need to draw segmented image
if (this.segm2d.model !== null) {
// we have loaded model: applt it to image
// we have loaded model: apply it to image
// TODO update image only on some specific events: zoom, explore
this.segm2d.startApplyImage();
}
} else {
Expand Down
29 changes: 19 additions & 10 deletions src/engine/Segm2d.js
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
// ********************************************************

import * as tf from '@tensorflow/tfjs';
const PATH_MODEL = 'https://lugachai.ru/med3web/tfjs/model.json';
import StoreActionType from '../store/ActionTypes';

const BRAIN_MODEl = 'https://daentjnvnffrh.cloudfront.net/models/brain/model.json';

// ********************************************************
// Const
Expand Down Expand Up @@ -45,9 +47,11 @@ const NUM_CLASSES = 96;
// ********************************************************

class Segm2d {
constructor(objGraphics2d) {
constructor(props) {
this.stage = STAGE_MODEL_NOT_LOADED;
this.objGraphics2d = objGraphics2d;
this.objGraphics2d = props;
this.store = props.store;

this.model = null;
this.tensorIndices = null;
this.imgData = null;
Expand Down Expand Up @@ -116,14 +120,14 @@ class Segm2d {
} // for (y)
}

//
// Load model
async onLoadModel() {
this.stage = STAGE_MODEL_IS_LOADING;
this.pixels = null;

console.log('Loading tfjs model...');
const modelLoaded = await tf.loadLayersModel(PATH_MODEL, { strict: false });
this.store.dispatch({ type: StoreActionType.SET_PROGRESS_INFO, titleProgressBar: 'Loading tfjs model...' });
const modelLoaded = await tf.loadLayersModel(BRAIN_MODEl, { strict: false, onProgress: this.onTFLoadProgress.bind(this) });

this.model = modelLoaded;
this.stage = STAGE_MODEL_READY;
Expand All @@ -135,9 +139,6 @@ class Segm2d {
}

async startApplyImage() {
if (this.stage === STAGE_SEGMENTATION_READY) {
return;
}
this.stage = STAGE_IMAGE_PROCESSED;
console.log('Start apply segm to image ...');

Expand Down Expand Up @@ -280,7 +281,7 @@ class Segm2d {
this.srcImageData = imgData;
}

render(ctx, w, h, imgData) {
renderImage(ctx, w, h, imgData) {
this.srcImageData = imgData;
this.wSrc = w;
this.hSrc = h;
Expand All @@ -291,7 +292,7 @@ class Segm2d {
console.log('Segm2d render. stage = ' + strMessage);

// load model
if (this.model === null) {
if (this.model === null && this.stage === STAGE_MODEL_NOT_LOADED) {
this.onLoadModel();
} else {
// change slider or similar: need to rebuild segm for the new source image
Expand Down Expand Up @@ -327,6 +328,14 @@ class Segm2d {
const y = h / 2;
ctx.fillText(strMsgPrint, x, y);
}

onTFLoadProgress(progress) {
this.store.dispatch({ type: StoreActionType.SET_PROGRESS, progress });
if (progress === 1) {
this.store.dispatch({ type: StoreActionType.SET_PROGRESS, progress: 0 });
this.store.dispatch({ type: StoreActionType.SET_PROGRESS_INFO, titleProgressBar: null });
}
}
}

function drawRoundedRect(ctx, x, y, width, height, borderRadius) {
Expand Down

0 comments on commit 594daa8

Please sign in to comment.