diff --git a/src/engine/Graphics2d.js b/src/engine/Graphics2d.js index 1104c73..d2eb1a2 100644 --- a/src/engine/Graphics2d.js +++ b/src/engine/Graphics2d.js @@ -27,6 +27,7 @@ class Graphics2d extends React.Component { constructor(props) { super(props); + this.store = props; this.m_mount = React.createRef(); this.onMouseDown = this.onMouseDown.bind(this); @@ -527,7 +528,7 @@ class Graphics2d extends React.Component { if (isSegm) { const w = this.m_toolPick.m_wScreen; const h = this.m_toolPick.m_hScreen; - this.segm2d.render(ctx, w, h, this.imgData); + this.segm2d.renderImage(ctx, w, h, this.imgData); } else { createImageBitmap(this.imgData) .then((imageBitmap) => { @@ -692,6 +693,12 @@ class Graphics2d extends React.Component { startY: evt.clientY, }); } + + if (this.m_isSegmented && this.segm2d.model) { + // We do not need update segmented image (with model) + // on mouse move event to performance issues. + return; + } store.graphics2d.forceUpdate(); } @@ -767,13 +774,14 @@ class Graphics2d extends React.Component { * Invoke forced rendering, after some tool visual changes */ forceUpdate(volIndex) { - // console.log('forceUpdate ...'); + console.log('forceUpdate ...'); this.prepareImageForRender(volIndex); // this.forceRender(); if (this.m_isSegmented) { // need to draw segmented image if (this.segm2d.model !== null) { - // we have loaded model: applt it to image + // we have loaded model: apply it to image + // TODO update image only on some specific events: zoom, explore this.segm2d.startApplyImage(); } } else { diff --git a/src/engine/Segm2d.js b/src/engine/Segm2d.js index 0f73ecc..76e35c6 100644 --- a/src/engine/Segm2d.js +++ b/src/engine/Segm2d.js @@ -14,7 +14,9 @@ // ******************************************************** import * as tf from '@tensorflow/tfjs'; -const PATH_MODEL = 'https://lugachai.ru/med3web/tfjs/model.json'; +import StoreActionType from '../store/ActionTypes'; + +const BRAIN_MODEl = 'https://daentjnvnffrh.cloudfront.net/models/brain/model.json'; // ******************************************************** // Const @@ -45,9 +47,11 @@ const NUM_CLASSES = 96; // ******************************************************** class Segm2d { - constructor(objGraphics2d) { + constructor(props) { this.stage = STAGE_MODEL_NOT_LOADED; - this.objGraphics2d = objGraphics2d; + this.objGraphics2d = props; + this.store = props.store; + this.model = null; this.tensorIndices = null; this.imgData = null; @@ -116,14 +120,14 @@ class Segm2d { } // for (y) } - // // Load model async onLoadModel() { this.stage = STAGE_MODEL_IS_LOADING; this.pixels = null; console.log('Loading tfjs model...'); - const modelLoaded = await tf.loadLayersModel(PATH_MODEL, { strict: false }); + this.store.dispatch({ type: StoreActionType.SET_PROGRESS_INFO, titleProgressBar: 'Loading tfjs model...' }); + const modelLoaded = await tf.loadLayersModel(BRAIN_MODEl, { strict: false, onProgress: this.onTFLoadProgress.bind(this) }); this.model = modelLoaded; this.stage = STAGE_MODEL_READY; @@ -135,9 +139,6 @@ class Segm2d { } async startApplyImage() { - if (this.stage === STAGE_SEGMENTATION_READY) { - return; - } this.stage = STAGE_IMAGE_PROCESSED; console.log('Start apply segm to image ...'); @@ -280,7 +281,7 @@ class Segm2d { this.srcImageData = imgData; } - render(ctx, w, h, imgData) { + renderImage(ctx, w, h, imgData) { this.srcImageData = imgData; this.wSrc = w; this.hSrc = h; @@ -291,7 +292,7 @@ class Segm2d { console.log('Segm2d render. stage = ' + strMessage); // load model - if (this.model === null) { + if (this.model === null && this.stage === STAGE_MODEL_NOT_LOADED) { this.onLoadModel(); } else { // change slider or similar: need to rebuild segm for the new source image @@ -327,6 +328,14 @@ class Segm2d { const y = h / 2; ctx.fillText(strMsgPrint, x, y); } + + onTFLoadProgress(progress) { + this.store.dispatch({ type: StoreActionType.SET_PROGRESS, progress }); + if (progress === 1) { + this.store.dispatch({ type: StoreActionType.SET_PROGRESS, progress: 0 }); + this.store.dispatch({ type: StoreActionType.SET_PROGRESS_INFO, titleProgressBar: null }); + } + } } function drawRoundedRect(ctx, x, y, width, height, borderRadius) {