Fortran-Code
!> Matrix multiplication with Strassen's algorithm plus direct reference products.
!>
!> Operators provided:
!>   a .mal. b  Strassen product (falls back to simple_matprod4 for small,
!>              odd-order or non-square operands)
!>   a .x. b    direct product with the i-j-k loop nest (simple_matprod)
!>   x * y      product of two type(t) values (elemental, so it also works on arrays)
module strassen
  implicit none

  !> Operands with at most this many rows are multiplied directly with
  !> simple_matprod4 instead of being split into 2x2 blocks again.
  integer, parameter :: strassen_cutoff = 64

  !> Wrapper around a single real value; multiplied via operator(*).
  type t
    real :: value
  end type t

  interface operator(.mal.)
    module procedure strassen_matprod
  end interface operator(.mal.)

  interface operator(.x.)
    module procedure simple_matprod
  end interface operator(.x.)

  interface operator(*)
    module procedure mul_t
  end interface operator(*)

contains

  !> Multiply the values wrapped by two type(t) objects.
  !> Elemental: works on scalars and on conformable arrays of type(t).
  elemental function mul_t(x, y) result(erg)
    type(t), intent(in) :: x, y
    type(t) :: erg

    erg%value = x%value * y%value
  end function mul_t

  !> Strassen product erg = a*b.
  !>
  !> Square operands of the same even order above strassen_cutoff are split
  !> into 2x2 blocks of half size.  The product is then assembled from seven
  !> recursive block products instead of eight.  All other operands are
  !> multiplied directly with simple_matprod4:
  !>   - small orders;
  !>   - odd orders, where a split would leave the last row/column unassigned;
  !>   - non-square operands.
  !>
  !> a    left factor,  shape (m, l)
  !> b    right factor, shape (l, p); size(b,1) must equal size(a,2)
  !> erg  result,       shape (m, p)
  recursive pure function strassen_matprod(a, b) result(erg)
    real, dimension(:,:), intent(in) :: a, b
    real, dimension(size(a,1), size(b,2)) :: erg
    ! Quadrant copies and the seven block products.  These are allocatable rather
    ! than automatic arrays: automatic arrays are commonly placed on the stack,
    ! which large orders can overflow.  They are only allocated when a split
    ! actually happens.
    real, dimension(:,:), allocatable :: m1, m2, m3, m4, m5, m6, m7, &
                                         a11, a12, a21, a22, b11, b12, b21, b22
    integer :: n, s

    n = size(a,1)
    if (n <= strassen_cutoff .or. mod(n,2) /= 0 .or. size(a,2) /= n .or. &
        size(b,1) /= n .or. size(b,2) /= n) then
      erg = simple_matprod4(a, b)
      return
    end if

    s = n/2      ! n is even here, so 2*s == n
    allocate(a11(s,s), a12(s,s), a21(s,s), a22(s,s), &
             b11(s,s), b12(s,s), b21(s,s), b22(s,s), &
             m1(s,s), m2(s,s), m3(s,s), m4(s,s), m5(s,s), m6(s,s), m7(s,s))

    a11 = a(1:s, 1:s)
    a12 = a(1:s, s+1:n)
    a21 = a(s+1:n, 1:s)
    a22 = a(s+1:n, s+1:n)
    b11 = b(1:s, 1:s)
    b12 = b(1:s, s+1:n)
    b21 = b(s+1:n, 1:s)
    b22 = b(s+1:n, s+1:n)

    ! Seven half-size products; each is again a Strassen product.
    m1 = (a12 - a22) .mal. (b21 + b22)
    m2 = (a11 + a22) .mal. (b11 + b22)
    m3 = (a11 - a21) .mal. (b11 + b12)
    m4 = (a11 + a12) .mal. b22
    m5 = a11 .mal. (b12 - b22)
    m6 = a22 .mal. (b21 - b11)
    m7 = (a21 + a22) .mal. b11

    erg(1:s, 1:s)     = m1 + m2 - m4 + m6
    erg(1:s, s+1:n)   = m4 + m5
    erg(s+1:n, 1:s)   = m6 + m7
    erg(s+1:n, s+1:n) = m2 - m3 + m5 - m7
  end function strassen_matprod

  !> Direct matrix product erg = a*b with the i-j-k loop nest.
  !>
  !> Each element is accumulated as a dot product over k.  Unlike
  !> simple_matprod4, the innermost loop therefore walks a along a row (the
  !> non-contiguous index in column-major storage).
  !>
  !> a    shape (m, l)
  !> b    shape (l, p); size(b,1) must equal size(a,2)
  !> erg  shape (m, p)
  pure function simple_matprod(a, b) result(erg)
    real, dimension(:,:), intent(in) :: a, b
    real, dimension(size(a,1), size(b,2)) :: erg
    real :: summe
    integer :: i, j, k

    do i = 1, size(a,1)
      do j = 1, size(b,2)
        summe = 0.0
        do k = 1, size(a,2)
          summe = summe + a(i,k)*b(k,j)
        end do
        erg(i,j) = summe
      end do
    end do
  end function simple_matprod

  !> Direct matrix product erg = a*b with the j-k-i loop nest.
  !>
  !> The innermost operation updates a whole column of erg from a column of a.
  !> It thus runs over the first (contiguous) index in column-major storage.
  !>
  !> a    shape (m, l)
  !> b    shape (l, p); size(b,1) must equal size(a,2)
  !> erg  shape (m, p)
  pure function simple_matprod4(a, b) result(erg)
    real, dimension(:,:), intent(in) :: a, b
    real, dimension(size(a,1), size(b,2)) :: erg
    integer :: j, k

    erg = 0.0
    do j = 1, size(b,2)
      do k = 1, size(a,2)
        erg(:,j) = erg(:,j) + a(:,k)*b(k,j)
      end do
    end do
  end function simple_matprod4

end module strassen
!> Benchmark driver for module strassen.
!>
!> Steps:
!>   1. Read k from standard input; let n = 2**k.
!>   2. Write k and then 2*n*n pseudo-random integers in [0, 19] to daten.dat.
!>   3. Read the values back as two n x n matrices.
!>   4. Print the CPU time used by the Strassen product, simple_matprod4 and
!>      the intrinsic matmul.
!>   5. Print whether all three results agree within a normwise relative
!>      tolerance.
program test
  use strassen, only: operator(.mal.), simple_matprod4
  implicit none

  !> Unit number of the data file.
  integer, parameter :: data_unit = 30
  !> Largest accepted k.  It keeps the value count 2*n*n = 2**(2k+1) within
  !> the default integer range.
  integer, parameter :: max_k = 14

  integer :: k, n, i, ios
  real, dimension(:,:), allocatable :: a, b, c1, c2, c3
  real :: z, rel_tol, c_max
  real :: tick(0:4)

  call cpu_time(tick(0))

  write(*,*) 'k:'
  read(*,*,iostat=ios) k
  if (ios /= 0 .or. k < 0 .or. k > max_k) then
    write(*,*) 'k must be an integer between 0 and', max_k
    stop 1
  end if

  ! The data file is (re)created by this program, so it must not be
  ! required to exist beforehand.
  open(data_unit, file='daten.dat', status='replace', action='readwrite', iostat=ios)
  if (ios /= 0) then
    write(*,*) 'cannot open daten.dat'
    stop 1
  end if

  ! File layout: k on the first record, then one integer per record.
  write(data_unit,*) k
  n = 2**k
  call random_seed()
  do i = 1, 2*n*n
    call random_number(z)
    write(data_unit,*) int(z*20)
  end do

  rewind(data_unit)
  read(data_unit,*) k
  n = 2**k
  allocate(a(n,n), b(n,n), c1(n,n), c2(n,n), c3(n,n))
  call read_matrix(a, 'a')
  call read_matrix(b, 'b')
  close(data_unit)

  call cpu_time(tick(1))
  write(*,*) tick(1) - tick(0)
  c1 = a .mal. b
  call cpu_time(tick(2))
  write(*,*) 'strassen:', tick(2) - tick(1)
  c2 = simple_matprod4(a, b)
  call cpu_time(tick(3))
  write(*,*) 'simple4:', tick(3) - tick(2)
  c3 = matmul(a, b)
  call cpu_time(tick(4))
  write(*,*) 'matmul:', tick(4) - tick(3)

  ! The algorithms round differently (Strassen forms sums and differences of
  ! blocks before multiplying), so exact equality is too strict for large n.
  ! Compare against the largest entry of the matmul result instead.
  rel_tol = sqrt(epsilon(1.0))
  c_max = max(1.0, maxval(abs(c3)))
  write(*,*) all(abs(c1 - c3) <= rel_tol*c_max .and. abs(c2 - c3) <= rel_tol*c_max)

contains

  !> Read matrix m from data_unit, one value per record, row by row
  !> (the second index varies fastest).  Stops the program on a read error
  !> instead of continuing with a partly undefined matrix.
  subroutine read_matrix(m, label)
    real, dimension(:,:), intent(out) :: m
    character(len=*), intent(in) :: label
    integer :: r, c, stat

    do r = 1, size(m,1)
      do c = 1, size(m,2)
        read(data_unit,*,iostat=stat) m(r,c)
        if (stat /= 0) then
          write(*,*) 'read error in matrix ', label
          stop 1
        end if
      end do
    end do
  end subroutine read_matrix

end program test
Eine weitere Implementierung des Algorithmus habe ich in Scilab geschrieben.