fortran Code
- module strassen
- implicit none
- type t
- real :: value
- end type t
- interface operator(.mal.)
- module procedure strassen_matprod
- end interface operator(.mal.)
- interface operator(.x.)
- module procedure simple_matprod
- end interface operator(.x.)
- interface operator(*)
- module procedure mul_t
- end interface operator(*)
- contains
- function mul_t(x,y) result(erg)
- type(t),intent(in) :: x,y
- type(t) :: erg
- erg%value = x%value * y%value
- end function mul_t
- recursive function strassen_matprod(a,b) result(erg)
- real,dimension(:,:),intent(in) :: a,b
- real,dimension(size(a,1),size(a,1)) :: erg
- real,dimension(size(a,1)/2,size(a,1)/2) :: m1,m2,m3,m4,m5,m6,m7, &
- & a11,a12,a21,a22,b11,b12,b21,b22
- integer :: s
- s = size(a,1)/2
- if(size(a,1) <= 64) then
- erg = simple_matprod4(a,b)
- else
- a11 = a(1:s,1:s)
- a12 = a(1:s,s+1:2*s)
- a21 = a(s+1:2*s,1:s)
- a22 = a(s+1:2*s,s+1:2*s)
- b11 = b(1:s,1:s)
- b12 = b(1:s,s+1:2*s)
- b21 = b(s+1:2*s,1:s)
- b22 = b(s+1:2*s,s+1:2*s)
- m1 = (a12-a22) .mal. (b21+b22)
- m2 = (a11+a22) .mal. (b11+b22)
- m3 = (a11-a21) .mal. (b11+b12)
- m4 = (a11+a12) .mal. b22
- m5 = a11 .mal. (b12-b22)
- m6 = a22 .mal. (b21-b11)
- m7 = (a21+a22) .mal. b11
- erg(1:s,1:s) = m1+m2-m4+m6
- erg(1:s,s+1:2*s) = m4+m5
- erg(s+1:2*s,1:s) = m6+m7
- erg(s+1:2*s,s+1:2*s) = m2-m3+m5-m7
- end if
- end function strassen_matprod
- function simple_matprod(a,b) result(erg)
- real,dimension(:,:),intent(in) :: a,b
- real,dimension(size(a,1),size(a,2)) :: erg
- real :: summe
- integer :: i,j,k
- erg=0
- do i=1,size(a,1)
- do j=1,size(a,1)
- do k=1,size(a,1)
- erg(i,j) = erg(i,j) + a(i,k)*b(k,j)
- end do
- end do
- end do
- end function simple_matprod
- function simple_matprod4(a,b) result(erg)
- real,dimension(:,:),intent(in) :: a,b
- real,dimension(size(a,1),size(a,2)) :: erg
- integer :: i,j,k,n
- erg=0
- n = size(a,1)
- do j=1,n
- do k=1,n
- do i=1,n
- erg(i,j) = erg(i,j) + a(i,k)*b(k,j)
- end do
- end do
- end do
- end function simple_matprod4
- end module strassen
- program test
- use strassen
- implicit none
- integer :: k,n,i,j,fehler
- real,dimension(:,:),allocatable :: a,b,c1,c2,c3
- real :: z
- open(30,file='daten.dat',status='old')
- write(*,*) 'k:'
- read(*,*) k
- write(30,*) k
- n = 2**k
- call random_seed()
- do i=1,2*n*n
- call random_number(z)
- write(30,*) int(z*20)
- end do
- rewind(30)
- read(30,*) k
- n = 2**k
- allocate(a(n,n),b(n,n),c1(n,n),c2(n,n),c3(n,n))
- aussen: do i=1,n
- do j=1,n
- read(30,*,iostat=fehler) a(i,j)
- if(fehler /= 0) exit aussen
- end do
- end do aussen
- aussen2: do i=1,n
- do j=1,n
- read(30,*,iostat=fehler) b(i,j)
- if(fehler /= 0) exit aussen2
- end do
- end do aussen2
- write(*,*) dtime()
- c1 = a .mal. b
- write(*,*) 'strassen:',dtime()
- c2 = simple_matprod4(a,b)
- write(*,*) 'simple4:',dtime()
- c3 = matmul(a,b)
- write(*,*) 'matmul:',dtime()
- write(*,*) all(c1 == c2 .and. c2 == c3)
- end program test
In der Funktion simple_matprod wurde auch noch eine Optimierung vorgenommen. Durch geschicktes Vertauschen der Schleifenindizes kann der Algorithmus beschleunigt werden. Dies liegt unter anderem daran, wie Fortran die Matrizen im Speicher ablegt. durch vertauschen der Indizes wird erzeugt, dass bei der Addition zweiter Elemente aus den Matrizen im Speicher nebeneinander stehende Werte gelesen werden können. Da beim Lesen aus dem Speicher nicht nur jeweils eine Zahl gelesen wird, sondern ein ganzer Block von vielleicht 16 oder 64 Zahlen, reicht für die Addition von 16 bzw. 64 Elementen ein Lesezugriff aus. Durch diese Einsparung erhöht sich die Geschwindigkeit einer einfachen Matrix-Multiplikation enorm. In anderen Progrmmiersprachen (z.B. C) werden Matrizen andersherum im Speicher abgelegt, so dass dort eine entsprechend andere Optimierung vorgenommen werden muss.
Eine weitere Implementierung des Algorithmus habe ich in Scilab geschrieben.