設計、テストときて、最後の実装です。実際には一つテストを書いては実装して、適度にリファクタリングしてを繰り返しているのですけど、それをいちいち書いていたら大変なので。全部を一気に説明すると大変なので、部分ごとに説明します。
クラスファイルの置き場
今回は MonoDouble というクラスなので、パスが通った場所に「@MonoDouble」というフォルダを作ります。この中に、MonoDouble.m というファイルや各メソッドごとのファイルを記述します。私の場合は、全部 MonoDouble.m にすべてのメソッドを実装しています。
メソッドの書き方
呼び出し的にはメソッド呼び出しですが、内部は関数として定義します。自分自身(self)を関数にわからせる必要があるので、メソッド呼び出しをすると自分自身が関数の第一引数に設定されます(私が C でオブジェクト指向もどきをやっていた時も似たような手段を使っていました)。また、クラスメソッドも静的メソッドとして定義ができます。
クラスを handle の継承で作成すると、通常のオブジェクト指向言語と同様に参照渡しになります。そのため、内部で self を変更することで、破壊的メソッドを作成できます。
クラスの属性(プロパティ)の設定
クラスの属性は Property で設定できます。Accessor は自分で記述することもできますが、何も設定しなければ、Public Accessor を用意してくれるようです。外からいじられたくなければちゃんと設定しなければなりません。
以上を踏まえて作った実装を示します。ポイントだけ日本語で解説してます.
% handle 型を継承してクラスを定義 (参照渡しになる) classdef MonoDouble < handle % プロパティ一覧 properties buffer % array buffer bits % the number of bits blockWidth % width of block blockHeight % height of block nowx % x of now position (from 0 to width-1) nowy % y of now position (from 0 to height-1) block % reference for sub MonoDouble end % インスタンスメソッド methods %=== Constructor methods % Constructor for 8bits array -> a = MonoDouble(array) % Constructor for n-bits array -> a = MonoDouble(array, n) function self = MonoDouble(m, b) % 引数が 1 つなら自動的に 8bpp if nargin == 1 b = 8; end self.buffer = double(m); self.bits = b; end %=== methods for properties function out = height(self); out = size(self.buffer, 1); end function out = width(self); out = size(self.buffer, 2); end %=== destructive methods (The last character of the method name is '_'.) function self = abs_(self); self.buffer = abs(self.buffer); end function self = add_(self, other); self = self.addMat_(other.buffer); end function self = addMat_(self, mat); self.buffer = self.buffer + mat; end function self = angle_(self); self.buffer = angle(self.buffer); end function self = clip_(self) self.round_; m = 2^self.bits - 1; self.buffer(self.buffer > m) = m; self.buffer(self.buffer < 0) = 0; end function self = dct_(self); self.buffer = dct2(self.buffer); end function self = fft_(self); self.buffer = fft2(self.buffer); end function self = fftshift_(self); self.buffer = fftshift(self.buffer); end function self = idct_(self); self.buffer = idct2(self.buffer); end function self = ifft_(self); self.buffer = ifft2(self.buffer); end function self = mul_(self, other); self = self.mulMat_(other.buffer); end function self = mulMat_(self, mat); self.buffer = self.buffer .* mat; end function self = rand_(self); self.buffer = rand(size(self.buffer)); end function self = round_(self); self.buffer = round(self.buffer); end function self = sign_(self); self.buffer = sign(self.buffer); end %=== nondestructive methods (call correspond destructive method inside) function out = abs(self); out = self.copy.abs_; end function out = add(self, other); out = self.copy.add_(other); end function out = addMat(self, mat); out = self.copy.addMat_(mat); end function out = angle(self); out = self.copy.angle_; end function out = clip(self); out = self.copy.clip_; end function out = dct(self); out = self.copy.dct_; end function out = fft(self); out = self.copy.fft_; end function out = fftshift(self); out = self.copy.fftshift_; end function out = idct(self); out = self.copy.idct_; end function out = ifft(self); out = self.copy.ifft_; end function out = mul(self, other); out = self.copy.mul_(other); end function out = mulMat(self, mat); out = self.copy.mulMat_(mat); end function out = rand(self); out = self.copy.rand_; end function out = round(self); out = self.copy.round_; end function out = sign(self); out = self.copy.sign_; end %=== other function out = copy(self); out = MonoDouble(self.buffer, self.bits); end function out = zeros(self); out = MonoDouble(zeros(size(self.buffer)), self.bits); end function disp(self); disp(self.bits); disp(self.buffer); end % display MonoDouble function PSNR = calcPSNR(self, other) d = floor(self.buffer) - floor(other.buffer); dd = d .* d; [y, x] = size(self.buffer); max = 2^self.bits - 1; PSNR = 10 * log10(max * max / sum(sum(dd)) / x / y)); end function printPSNR(self, other, str) if nargin == 2 str = ''; end display(strcat(num2str(self.calcPSNR(other)), ' :', str)) end function imwrite(self, name) if self.bits <= 8 imwrite(uint8(self.buffer), name); elseif self.bits <= 16 imwrite(uint16(self.buffer), name); else imwrite(uint32(self.buffer), name); end end function imshow(self) if self.bits <= 8 imshow(uint8(self.buffer)); elseif self.bits <= 16 imshow(uint16(self.buffer)); else imshow(uint32(self.buffer)); end end function out = head(self, r, c) if nargin == 1 r = 10; c = 10; elseif nargin == 2 c = 10; end if self.width < c c = self.width; end if self.height < r r = self.height; end out = self.buffer(1:r, 1:c); end %=== methods for iterator function self = setBlockSize_(self, by, bx) if by > self.height by = self.height; end if bx > self.width bx = self.width; end self.blockWidth = bx; self.blockHeight = by; self.nowx = 0; self.nowy = 0; self.block = MonoDouble(zeros(self.blockHeight, self.blockWidth)); end function iterateStart_(self); self.setBlockSize_(1, 1); end function out = value(self) if self.nowx == -1 out = nan; else out = self.buffer(self.nowy+1, self.nowx+1); end end function out = getBlock_(self) if self.nowx == -1 out = []; else w = self.width - self.nowx; if w > self.blockWidth w = self.blockWidth; end h = self.height - self.nowy; if h > self.blockHeight h = self.blockHeight; end self.block.buffer = self.buffer(self.nowy+1:self.nowy+h, self.nowx+1:self.nowx+w); out = self.block; end end function setBlock_(self, block) if self.nowx == -1 disp('already finished!'); else w = self.width - self.nowx; if w > self.blockWidth w = self.blockWidth; end h = self.height - self.nowy; if h > self.blockHeight h = self.blockHeight; end self.buffer(self.nowy+1:self.nowy+h, self.nowx+1:self.nowx+w) = block.buffer; end end function out = next_(self) if self.nowx == -1 out = 0; else out = 1; self.nowx = self.nowx + self.blockWidth; if self.nowx >= self.width self.nowx = 0; self.nowy = self.nowy + self.blockHeight; if self.nowy >= self.height self.nowx = -1; self.nowy = -1; out = 0; end end end end end % クラスメソッド methods(Static) function out = imread(name, b) if nargin == 1 b = 8; end out = MonoDouble(double(imread(name)), b); end end end
例外処理を一部サボっているけど,使うのは自分だけだからなんとかなるかな.時間があればこれからリファクタリングしていきます.