"use strict";

// Number of trainable item classes; one training button is created per class.
const NUM_CLASSES = 8;
// Default side length (pixels) used when an image has to be resized manually
// for a model whose input size cannot be read from the model itself.
// The MobileNet models loaded by this file take 224x224 input.
const IMAGE_SIZE = 224;
// Value passed as `k` to the KNN classifier when predicting a class.
const TOPK = 3;
/**
 * Browser demo that trains a KNN classifier on user-selected images, using
 * MobileNet activations as the feature vector, and classifies new images.
 *
 * Relies on the globals `tf`, `knnClassifier` and `mobilenet` (loaded through
 * script tags), on the browser globals `document`, `alert` and `FileReader`,
 * and on the file-level constants NUM_CLASSES, IMAGE_SIZE and TOPK.
 */
class Main {
  /**
   * Looks up the page elements, builds the per-class training controls, wires
   * the event listeners and starts loading the models in the background.
   */
  constructor() {
    this.infoTexts = []; // one status <span> per class
    this.trainingClass = -1; // kept for compatibility; not read inside this class
    this.currentImageElement = null; // kept for compatibility; not read inside this class
    this.isAppReady = false; // becomes true once the models have loaded
    this.trainingButtons = []; // one "train" button per class

    // Page elements driven by the app.
    this.fileInput = document.getElementById("file-input");
    this.imagePreview = document.getElementById("image-preview");
    this.imagePlaceholder = document.getElementById("preview-placeholder");
    this.classifyButton = document.getElementById("classify-button");
    this.controlsContainer = document.getElementById("controls-container");
    this.statusDiv = document.getElementById("status");

    // Libraries exposed as globals by the CDN script tags.
    this.tf = tf; // @tensorflow/tfjs
    this.knnClassifier = knnClassifier; // @tensorflow-models/knn-classifier
    this.mobilenet = mobilenet; // @tensorflow-models/mobilenet

    // Build the UI first so every button exists before the asynchronous model
    // load gets a chance to enable it.
    this.createTrainingControls();
    this.bindEventListeners();

    // bindPage() reports its own failures, so the promise is deliberately not
    // awaited here (constructors cannot be async).
    void this.bindPage();
  }

  /**
   * Shows a status message on the page and mirrors it to the console.
   * @param {string} message - Text to display.
   */
  updateStatus(message) {
    this.statusDiv.innerText = message;
    console.log(message);
  }

  /**
   * Tells whether the preview element currently shows a user-selected image.
   * The "#" suffix check matches the placeholder `src` that the file-input
   * handler assigns when the selection is cleared.
   * @returns {boolean} True when an image is available for training/classifying.
   */
  hasPreviewImage() {
    const preview = this.imagePreview;
    return Boolean(
      preview &&
        preview.style.display !== "none" &&
        preview.src &&
        !preview.src.endsWith("#")
    );
  }

  /**
   * Enables or disables the classify button and every training button.
   * @param {boolean} disabled - New `disabled` state for all action buttons.
   */
  setActionButtonsDisabled(disabled) {
    this.classifyButton.disabled = disabled;
    this.trainingButtons.forEach((btn) => {
      btn.disabled = disabled;
    });
  }

  /**
   * Creates, for each of the NUM_CLASSES classes, a row holding a training
   * button (initially disabled) and a sample-count label.
   */
  createTrainingControls() {
    for (let i = 0; i < NUM_CLASSES; i++) {
      const row = document.createElement("div");
      row.classList.add("class-controls");
      this.controlsContainer.appendChild(row);

      const button = document.createElement("button");
      button.innerText = `训练物品 ${i}`;
      button.disabled = true; // enabled once the model is ready and an image is chosen
      row.appendChild(button);

      // Clicking trains class `i` with the image currently shown in the preview.
      button.addEventListener("click", async () => {
        if (!this.hasPreviewImage()) {
          alert("请先选择一张图片再进行训练!");
          return;
        }
        this.updateStatus(`正在为物品 ${i} 添加样本...`);
        const added = await this.addImageForTraining(i);
        // Only announce success when the sample really was added; otherwise
        // keep the error status written by addImageForTraining().
        if (added) {
          this.updateStatus(`物品 ${i} 的样本已添加。`);
        }
      });

      const infoText = document.createElement("span");
      infoText.innerText = " 0 个学习样例";
      row.appendChild(infoText);

      this.infoTexts.push(infoText);
      this.trainingButtons.push(button);
    }
  }

  /**
   * Wires the file-input "change" handler (loads the chosen image into the
   * preview) and the classify button "click" handler.
   */
  bindEventListeners() {
    this.fileInput.addEventListener("change", (event) => {
      const file = event.target.files[0];

      if (!file || !file.type.startsWith("image/")) {
        // Nothing usable selected: reset the preview and lock the buttons.
        this.imagePreview.src = "#";
        this.imagePreview.style.display = "none";
        this.imagePlaceholder.style.display = "block";
        this.setActionButtonsDisabled(true);
        if (file) {
          alert("请选择一个图片文件!");
          this.updateStatus("请选择一个图片文件。");
        } else {
          this.updateStatus("未选择图片。");
        }
        return;
      }

      const reader = new FileReader();
      reader.onload = (e) => {
        this.imagePreview.src = e.target.result;
        this.imagePreview.style.display = "block";
        this.imagePlaceholder.style.display = "none";
        // Buttons stay disabled until the models have finished loading.
        this.setActionButtonsDisabled(!this.isAppReady);
        this.updateStatus("图片已加载,可以进行识别或训练。");
      };
      reader.readAsDataURL(file);
    });

    this.classifyButton.addEventListener("click", async () => {
      if (!this.hasPreviewImage()) {
        alert("请先选择一张图片再进行识别!");
        return;
      }
      await this.classifyCurrentImage();
    });
  }

  /**
   * Creates the KNN classifier and loads MobileNet (version 1 first, falling
   * back to version 2). On success marks the app as ready and enables the
   * buttons if an image is already selected; on failure reports the error.
   * @returns {Promise<void>} Settles when initialisation has finished or failed.
   */
  async bindPage() {
    try {
      this.updateStatus("正在初始化应用...");
      this.knn = this.knnClassifier.create();
      this.updateStatus("KNN 分类器已创建");

      this.updateStatus("正在加载 MobileNet 模型...");
      try {
        this.mobilenetModel = await this.mobilenet.load({
          version: 1,
          alpha: 1.0,
        });
        console.log("MobileNet V1 模型加载完成");
      } catch (e) {
        console.warn("MobileNet V1 加载失败,尝试加载 V2:", e);
        this.mobilenetModel = await this.mobilenet.load({
          version: 2,
          alpha: 1.0,
        });
        console.log("MobileNet V2 模型加载完成");
      }

      this.updateStatus("模型加载完成,应用准备就绪。");
      this.isAppReady = true;

      // The user may have picked an image while the model was still loading.
      // Only unlock the buttons when an image is really there.
      if (this.hasPreviewImage()) {
        this.setActionButtonsDisabled(false);
      }
    } catch (error) {
      console.error("模型加载或初始化失败:", error);
      this.updateStatus("错误:模型加载失败。请检查网络或控制台。");
      alert("模型加载失败,请检查网络连接或控制台错误。");
    }
  }

  /**
   * Determines the square input size to use when an image must be resized by
   * hand. Reads `mobilenetModel.inputs[0].shape` when present and falls back
   * to IMAGE_SIZE otherwise.
   * @returns {number} Side length in pixels.
   */
  resolveInputSize() {
    const shape = this.mobilenetModel?.inputs?.[0]?.shape;
    if (shape == null) {
      return IMAGE_SIZE;
    }
    if (
      Array.isArray(shape) &&
      shape.length >= 3 &&
      shape[1] != null &&
      shape[1] === shape[2]
    ) {
      return shape[1];
    }
    console.warn(
      `无法从模型输入形状 ${JSON.stringify(
        shape
      )} 推断尺寸,使用默认值 ${IMAGE_SIZE}。`
    );
    return IMAGE_SIZE;
  }

  /**
   * Extracts a feature tensor from an image element.
   *
   * When the loaded model exposes `infer()` (the @tensorflow-models/mobilenet
   * wrapper) the raw pixel tensor is handed over unchanged: per that package's
   * documentation `infer()` resizes and normalizes [0, 255] pixel input
   * itself, so normalizing beforehand would scale the data twice. Manual
   * resize + [-1, 1] normalization is applied only for models that expose a
   * bare `predict()`.
   *
   * Intermediate tensors are released by `tf.tidy`; the returned tensor is
   * owned by the caller, who must dispose it.
   *
   * @param {*} imageElement - Pixel source accepted by `tf.browser.fromPixels`.
   * @returns {Promise<*>} The feature tensor, or null when extraction failed.
   */
  async getFeaturesFromImage(imageElement) {
    const { tf: tfLib, mobilenetModel: model } = this;
    try {
      return tfLib.tidy(() => {
        const pixels = tfLib.browser.fromPixels(imageElement);

        if (typeof model.infer === "function") {
          // `true` requests the embedding rather than class logits.
          return model.infer(pixels, true);
        }

        if (typeof model.predict === "function") {
          const size = this.resolveInputSize();
          const batched = tfLib.image
            .resizeBilinear(pixels, [size, size], true)
            .expandDims(0);
          const normalized = batched.toFloat().div(127.5).sub(1.0);

          let output = model.predict(normalized);
          if (Array.isArray(output)) {
            output = output[0];
          }
          // Flatten everything after the batch dimension.
          if (output.rank > 2) {
            output = output.reshape([output.shape[0], -1]);
          }
          return output;
        }

        throw new Error("无法识别的模型接口");
      });
    } catch (error) {
      console.error("MobileNet 推理失败: ", error);
      return null;
    }
  }

  /**
   * Adds the previewed image as a training sample for one class.
   * @param {number} classIndex - Class (item number) to train.
   * @returns {Promise<boolean>} True when the sample was added, false when a
   *   precondition failed or an error occurred (a status/alert is shown).
   */
  async addImageForTraining(classIndex) {
    if (!this.isAppReady) {
      alert("模型尚未准备好,请稍候。");
      return false;
    }
    if (!this.hasPreviewImage()) {
      alert("请先选择一张图片。");
      return false;
    }

    const logits = await this.getFeaturesFromImage(this.imagePreview);
    if (!logits) {
      this.updateStatus(`错误:无法从图片提取特征用于训练物品 ${classIndex}`);
      return false;
    }

    try {
      this.knn.addExample(logits, classIndex);
      this.updateExampleCountUI(classIndex);
      return true;
    } catch (error) {
      console.error("添加训练样本失败:", error);
      this.updateStatus(`错误:无法为物品 ${classIndex} 添加样本。`);
      return false;
    } finally {
      // The feature tensor escaped tf.tidy, so release it on every path.
      logits.dispose();
    }
  }

  /**
   * Classifies the previewed image with the KNN classifier and updates the
   * per-class labels and the status line with the outcome.
   * @returns {Promise<void>} Settles when the UI has been updated.
   */
  async classifyCurrentImage() {
    if (!this.isAppReady) {
      alert("模型尚未准备好,请稍候。");
      return;
    }
    if (!this.hasPreviewImage()) {
      alert("请先选择一张图片。");
      return;
    }

    this.updateStatus("正在识别图片...");
    if (this.knn.getNumClasses() === 0) {
      this.updateStatus("错误:您还没有训练任何物品类别!");
      alert("请先训练至少一个物品类别再进行识别。");
      return;
    }

    const logits = await this.getFeaturesFromImage(this.imagePreview);
    if (!logits) {
      this.updateStatus("错误:无法从图片提取特征进行识别。");
      return;
    }

    try {
      const res = await this.knn.predictClass(logits, TOPK);
      this.updatePredictionUI(res);

      const predicted = this.resolvePredictedClass(res);
      if (predicted < 0) {
        this.updateStatus("识别完成,但未能确定类别。");
        return;
      }
      const confidence = res.confidences?.[predicted] ?? 0;
      this.updateStatus(
        `识别完成。最可能的类别: 物品 ${predicted} (${(
          confidence * 100
        ).toFixed(1)}%)`
      );
    } catch (error) {
      console.error("识别失败:", error);
      this.updateStatus("错误:识别过程中出现问题,请查看控制台。");
    } finally {
      // The feature tensor escaped tf.tidy, so release it on every path.
      logits.dispose();
    }
  }

  /**
   * Maps a KNN prediction to the item number used by this UI.
   *
   * Prefers `label` (the value that was passed to `addExample`). In current
   * knn-classifier releases `classIndex` is documented as a 0-based class
   * index kept for backwards compatibility, so it only equals the item number
   * when items were trained in order 0, 1, 2, ... Falls back to `classIndex`
   * for results that carry no `label`.
   *
   * @param {*} predictionResult - Value resolved by `knn.predictClass`.
   * @returns {number} The item number, or -1 when none can be determined.
   */
  resolvePredictedClass(predictionResult) {
    if (predictionResult == null) {
      return -1;
    }
    const resolved = Number(
      predictionResult.label ?? predictionResult.classIndex
    );
    return Number.isInteger(resolved) ? resolved : -1;
  }

  /**
   * Refreshes the sample-count part of one class label, keeping any
   * confidence text (the part after " - ") that is already displayed.
   * @param {number} classIndex - Class whose label should be refreshed.
   */
  updateExampleCountUI(classIndex) {
    const infoText = this.infoTexts[classIndex];
    if (!infoText) {
      return;
    }
    const exampleCount = this.knn.getClassExampleCount();
    const count = exampleCount[classIndex] || 0;
    const parts = infoText.innerText.split(" - ");
    const confidenceSuffix = parts.length > 1 ? ` - ${parts[1]}` : "";
    infoText.innerText = ` ${count} 个学习样例${confidenceSuffix}`;
    infoText.style.fontWeight = "normal";
  }

  /**
   * Writes sample count and confidence into every class label and renders the
   * predicted class in bold.
   * @param {*} predictionResult - Value resolved by `knn.predictClass`.
   */
  updatePredictionUI(predictionResult) {
    const exampleCount = this.knn.getClassExampleCount();
    const predicted = this.resolvePredictedClass(predictionResult);

    for (let i = 0; i < NUM_CLASSES; i++) {
      const infoText = this.infoTexts[i];
      if (!infoText) {
        continue;
      }
      const count = exampleCount[i] || 0;
      const rawConfidence = predictionResult?.confidences?.[i];
      // Classes without a confidence entry are shown as 0%.
      const confidence =
        rawConfidence != null ? (rawConfidence * 100).toFixed(1) : "0.0";
      infoText.innerText = ` ${count} 个学习样例 - ${confidence}%`;
      infoText.style.fontWeight = predicted === i ? "bold" : "normal";
    }
  }
}
// Start the app once the page has finished loading. If this script happens to
// run after the load event has already fired, start immediately instead of
// waiting for an event that will not come again.
if (document.readyState === "complete") {
  new Main();
} else {
  window.addEventListener("load", () => new Main());
}